From 4898d3c1fd769807d83df31e67e99c24a8373185 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 7 May 2018 10:15:59 +0800 Subject: [PATCH] Refactor model benchmark with new format. --- mace/benchmark/BUILD | 10 +- mace/benchmark/benchmark_model.cc | 112 ++++--------- mace/benchmark/statistics.cc | 269 ++++++++++++++++++++++++++++++ mace/benchmark/statistics.h | 159 ++++++++++++++++++ mace/utils/BUILD | 1 + mace/utils/string_util.cc | 84 ++++++++++ mace/utils/string_util.h | 7 + 7 files changed, 555 insertions(+), 87 deletions(-) create mode 100644 mace/benchmark/statistics.cc create mode 100644 mace/benchmark/statistics.h create mode 100644 mace/utils/string_util.cc diff --git a/mace/benchmark/BUILD b/mace/benchmark/BUILD index 50ed42ee..0cc23bb1 100644 --- a/mace/benchmark/BUILD +++ b/mace/benchmark/BUILD @@ -11,14 +11,12 @@ load( licenses(["notice"]) # Apache 2.0 cc_library( - name = "stat_summarizer", - srcs = ["stat_summarizer.cc"], - hdrs = ["stat_summarizer.h"], + name = "statistics", + srcs = ["statistics.cc"], + hdrs = ["statistics.h"], linkstatic = 1, deps = [ - "//mace/core", "//mace/kernels", - "//mace/public", "//mace/utils", ], ) @@ -31,7 +29,7 @@ cc_binary( linkopts = if_openmp_enabled(["-fopenmp"]), linkstatic = 1, deps = [ - ":stat_summarizer", + ":statistics", "//external:gflags_nothreads", "//mace/codegen:generated_models", ], diff --git a/mace/benchmark/benchmark_model.cc b/mace/benchmark/benchmark_model.cc index 00f46ab6..b282af94 100644 --- a/mace/benchmark/benchmark_model.cc +++ b/mace/benchmark/benchmark_model.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include // NOLINT(build/c++11) @@ -23,7 +24,7 @@ #include "mace/public/mace.h" #include "mace/public/mace_runtime.h" #include "mace/utils/logging.h" -#include "mace/benchmark/stat_summarizer.h" +#include "mace/benchmark/statistics.h" namespace mace { namespace MACE_MODEL_TAG { @@ -120,12 +121,12 @@ DeviceType ParseDeviceType(const std::string &device_str) { bool RunInference(MaceEngine *engine, const std::map &input_infos, std::map *output_infos, - StatSummarizer *summarizer, - int64_t *inference_time_us) { + int64_t *inference_time_us, + OpStat *statistician) { MACE_CHECK_NOTNULL(output_infos); RunMetadata run_metadata; RunMetadata *run_metadata_ptr = nullptr; - if (summarizer) { + if (statistician) { run_metadata_ptr = &run_metadata; } @@ -139,39 +140,33 @@ bool RunInference(MaceEngine *engine, } *inference_time_us = end_time - start_time; - if (summarizer != nullptr) { - summarizer->ProcessMetadata(run_metadata); + if (statistician != nullptr) { + statistician->StatMetadata(run_metadata); } return true; } -bool Run(MaceEngine *engine, +bool Run(const std::string &title, + MaceEngine *engine, const std::map &input_infos, std::map *output_infos, - StatSummarizer *summarizer, int num_runs, double max_time_sec, - int64_t sleep_sec, int64_t *total_time_us, - int64_t *actual_num_runs) { + int64_t *actual_num_runs, + OpStat *statistician) { MACE_CHECK_NOTNULL(output_infos); *total_time_us = 0; - LOG(INFO) << "Running benchmark for max " << num_runs << " iterators, max " - << max_time_sec << " seconds " - << (summarizer != nullptr ? "with " : "without ") - << "detailed stat logging, with " << sleep_sec - << "s sleep between inferences"; - - Stat stat; + TimeInfo time_info; bool util_max_time = (num_runs <= 0); for (int i = 0; util_max_time || i < num_runs; ++i) { int64_t inference_time_us = 0; bool s = RunInference(engine, input_infos, output_infos, - summarizer, &inference_time_us); - stat.UpdateStat(inference_time_us); + &inference_time_us, statistician); + time_info.UpdateTime(inference_time_us); (*total_time_us) += inference_time_us; ++(*actual_num_runs); @@ -183,16 +178,13 @@ bool Run(MaceEngine *engine, LOG(INFO) << "Failed on run " << i; return s; } - - if (sleep_sec > 0) { - std::this_thread::sleep_for(std::chrono::seconds(sleep_sec)); - } } - std::stringstream stream; - stat.OutputToStream(&stream); - LOG(INFO) << stream.str(); - + std::stringstream stream(time_info.ToString(title)); + stream << std::endl; + for (std::string line; std::getline(stream, line);) { + LOG(INFO) << line; + } return true; } @@ -206,19 +198,7 @@ DEFINE_string(output_shape, "", "output shape, separated by colon and comma"); DEFINE_string(input_file, "", "input file name"); DEFINE_int32(max_num_runs, 100, "number of runs max"); DEFINE_string(max_time, "10.0", "length to run max"); -DEFINE_string(inference_delay, "-1", "delay between runs in seconds"); -DEFINE_string(inter_benchmark_delay, "-1", - "delay between benchmarks in seconds"); DEFINE_string(benchmark_name, "", "benchmark name"); -DEFINE_bool(show_run_order, true, "whether to list stats by run order"); -DEFINE_int32(run_order_limit, 0, "how many items to show by run order"); -DEFINE_bool(show_time, true, "whether to list stats by time taken"); -DEFINE_int32(time_limit, 10, "how many items to show by time taken"); -DEFINE_bool(show_memory, false, "whether to list stats by memory used"); -DEFINE_int32(memory_limit, 10, "how many items to show by memory used"); -DEFINE_bool(show_type, true, "whether to list stats by op type"); -DEFINE_bool(show_summary, true, "whether to show a summary of the stats"); -DEFINE_bool(show_flops, true, "whether to estimate the model's FLOPs"); DEFINE_int32(warmup_runs, 1, "how many runs to initialize model"); DEFINE_string(model_data_file, "", "model data file name, used when EMBED_MODEL_DATA set to 0"); @@ -246,30 +226,12 @@ int Main(int argc, char **argv) { LOG(INFO) << "output shapes: [" << FLAGS_output_shape << "]"; LOG(INFO) << "Warmup runs: [" << FLAGS_warmup_runs << "]"; LOG(INFO) << "Num runs: [" << FLAGS_max_num_runs << "]"; - LOG(INFO) << "Inter-inference delay (seconds): [" - << FLAGS_inference_delay << "]"; - LOG(INFO) << "Inter-benchmark delay (seconds): [" - << FLAGS_inter_benchmark_delay << "]"; - - const int64_t inter_inference_sleep_seconds = - std::strtol(FLAGS_inference_delay.c_str(), nullptr, 10); - const int64_t inter_benchmark_sleep_seconds = - std::strtol(FLAGS_inter_benchmark_delay.c_str(), nullptr, 10); + LOG(INFO) << "Max run time: [" << FLAGS_max_time << "]"; + const double max_benchmark_time_seconds = std::strtod(FLAGS_max_time.c_str(), nullptr); - std::unique_ptr stats; - - StatSummarizerOptions stats_options; - stats_options.show_run_order = FLAGS_show_run_order; - stats_options.run_order_limit = FLAGS_run_order_limit; - stats_options.show_time = FLAGS_show_time; - stats_options.time_limit = FLAGS_time_limit; - stats_options.show_memory = FLAGS_show_memory; - stats_options.memory_limit = FLAGS_memory_limit; - stats_options.show_type = FLAGS_show_type; - stats_options.show_summary = FLAGS_show_summary; - stats.reset(new StatSummarizer(stats_options)); + std::unique_ptr statistician(new OpStat()); mace::DeviceType device_type = ParseDeviceType(FLAGS_device); @@ -349,50 +311,38 @@ int Main(int argc, char **argv) { mace::MACE_MODEL_TAG::UnloadModelData(model_data); } - LOG(INFO) << "Warm up"; - int64_t warmup_time_us = 0; int64_t num_warmup_runs = 0; if (FLAGS_warmup_runs > 0) { bool status = - Run(engine_ptr.get(), inputs, &outputs, nullptr, + Run("Warm Up", engine_ptr.get(), inputs, &outputs, FLAGS_warmup_runs, -1.0, - inter_inference_sleep_seconds, &warmup_time_us, &num_warmup_runs); + &warmup_time_us, &num_warmup_runs, nullptr); if (!status) { LOG(ERROR) << "Failed at warm up run"; } } - if (inter_benchmark_sleep_seconds > 0) { - std::this_thread::sleep_for( - std::chrono::seconds(inter_benchmark_sleep_seconds)); - } int64_t no_stat_time_us = 0; int64_t no_stat_runs = 0; bool status = - Run(engine_ptr.get(), inputs, &outputs, - nullptr, FLAGS_max_num_runs, max_benchmark_time_seconds, - inter_inference_sleep_seconds, &no_stat_time_us, &no_stat_runs); + Run("Run without statistics", engine_ptr.get(), inputs, &outputs, + FLAGS_max_num_runs, max_benchmark_time_seconds, + &no_stat_time_us, &no_stat_runs, nullptr); if (!status) { LOG(ERROR) << "Failed at normal no-stat run"; } int64_t stat_time_us = 0; int64_t stat_runs = 0; - status = Run(engine_ptr.get(), inputs, &outputs, - stats.get(), FLAGS_max_num_runs, max_benchmark_time_seconds, - inter_inference_sleep_seconds, &stat_time_us, &stat_runs); + status = Run("Run with statistics", engine_ptr.get(), inputs, &outputs, + FLAGS_max_num_runs, max_benchmark_time_seconds, + &stat_time_us, &stat_runs, statistician.get()); if (!status) { LOG(ERROR) << "Failed at normal stat run"; } - LOG(INFO) << "Average inference timings in us: " - << "Warmup: " - << (FLAGS_warmup_runs > 0 ? warmup_time_us / FLAGS_warmup_runs : 0) - << ", " << "no stats: " << no_stat_time_us / no_stat_runs << ", " - << "with stats: " << stat_time_us / stat_runs; - - stats->PrintOperatorStats(); + statistician->PrintStat(); return 0; } diff --git a/mace/benchmark/statistics.cc b/mace/benchmark/statistics.cc new file mode 100644 index 00000000..60ca2fd5 --- /dev/null +++ b/mace/benchmark/statistics.cc @@ -0,0 +1,269 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/benchmark/statistics.h" + +#include + +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/public/mace_types.h" +#include "mace/utils/logging.h" +#include "mace/utils/string_util.h" + +namespace mace { +namespace benchmark { + +namespace { +std::string MetricToString(const Metric metric) { + switch (metric) { + case NAME: + return "Name"; + case RUN_ORDER: + return "Run Order"; + case COMPUTATION_TIME: + return "Computation Time"; + default: + return ""; + } +} + +std::string PaddingTypeToString(int padding_type) { + std::stringstream stream; + Padding type = static_cast(padding_type); + switch (type) { + case VALID: stream << "VALID"; break; + case SAME: stream << "SAME"; break; + case FULL: stream << "FULL"; break; + default: stream << padding_type; break; + } + + return stream.str(); +} + +std::string ShapeToString(const std::vector &output_shape) { + if (output_shape.empty()) { + return ""; + } + + std::stringstream stream; + stream << "["; + for (int i = 0; i < output_shape.size(); ++i) { + const std::vector &dims = output_shape[i].dims(); + for (int j = 0; j < dims.size(); ++j) { + stream << dims[j]; + if (j != dims.size() - 1) { + stream << ","; + } + } + if (i != output_shape.size() - 1) { + stream << ":"; + } + } + stream << "]"; + + return stream.str(); +} + +template +std::string VectorToString(const std::vector &vec) { + if (vec.empty()) { + return ""; + } + + std::stringstream stream; + stream << "["; + for (int i = 0; i < vec.size(); ++i) { + stream << vec[i]; + if (i != vec.size() - 1) { + stream << ","; + } + } + stream << "]"; + + return stream.str(); +} + +} // namespace + +void OpStat::StatMetadata(const RunMetadata &meta_data) { + if (meta_data.op_stats.empty()) { + LOG(FATAL) << "Op metadata should not be empty"; + } + int64_t order_idx = 0; + int64_t total_time = 0; + + const int64_t first_op_start_time = meta_data.op_stats[0].stats.start_micros; + + for (auto &op_stat : meta_data.op_stats) { + auto result = records_.emplace(op_stat.operator_name, Record()); + Record *record = &(result.first->second); + + if (result.second) { + record->name = op_stat.operator_name; + record->type = op_stat.type; + record->args = op_stat.args; + record->output_shape = op_stat.output_shape; + record->order = order_idx; + order_idx += 1; + } + record->start.UpdateTime(op_stat.stats.start_micros - first_op_start_time); + int64_t run_time = op_stat.stats.end_micros - op_stat.stats.start_micros; + record->rel_end.UpdateTime(run_time); + record->called_times += 1; + total_time += run_time; + } + total_time_.UpdateTime(total_time); +} + +std::string OpStat::StatByMetric(const Metric metric, + const int top_limit) const { + if (records_.empty()) { + return ""; + } + // sort + std::vector records; + for (auto &record : records_) { + records.push_back(record.second); + } + std::sort(records.begin(), records.end(), + [=](const Record &lhs, const Record &rhs) { + if (metric == RUN_ORDER) { + return lhs.order < rhs.order; + } else if (metric == NAME) { + return lhs.name.compare(rhs.name) < 0; + } else { + return lhs.rel_end.avg() > rhs.rel_end.avg(); + } + }); + + // generate string + std::string title = "Sort by " + MetricToString(metric); + const std::vector header = { + "Node Type", "Start", "First", "Avg(ms)", "%", "cdf%", + "Stride", "Pad", "Filter Shape", "Output Shape", "Dilation", "name" + }; + std::vector> data; + int count = top_limit; + if (top_limit <= 0) count = static_cast(records.size()); + + int64_t accumulate_time = 0; + for (int i = 0; i < count; ++i) { + Record &record = records[i]; + accumulate_time += record.rel_end.sum(); + + std::vector tuple; + tuple.push_back(record.type); + tuple.push_back(FloatToString(record.start.avg() / 1000.0f, 3)); + tuple.push_back(FloatToString(record.rel_end.first() / 1000.0f, 3)); + tuple.push_back(FloatToString(record.rel_end.avg() / 1000.0f, 3)); + tuple.push_back( + FloatToString(record.rel_end.sum() * 100.f / total_time_.sum(), 3)); + tuple.push_back( + FloatToString(accumulate_time * 100.f / total_time_.sum(), 3)); + tuple.push_back(VectorToString(record.args.strides)); + if (record.args.padding_type != -1) { + tuple.push_back(PaddingTypeToString(record.args.padding_type)); + } else { + tuple.push_back(VectorToString(record.args.paddings)); + } + tuple.push_back(VectorToString(record.args.kernels)); + tuple.push_back(ShapeToString(record.output_shape)); + tuple.push_back(VectorToString(record.args.dilations)); + tuple.push_back(record.name); + data.emplace_back(tuple); + } + return mace::string_util::StringFormatter::Table(title, header, data); +} + +std::string OpStat::StatByNodeType() const { + if (records_.empty()) { + return ""; + } + const int64_t round = total_time_.round(); + int64_t total_time = 0; + std::map type_time_map; + std::map type_count_map; + std::map type_called_times_map; + std::set node_types_set; + for (auto &record : records_) { + std::string node_type = record.second.type; + node_types_set.insert(node_type); + + type_time_map[node_type] += record.second.rel_end.sum() / round; + total_time += record.second.rel_end.sum() / round; + type_count_map[node_type] += 1; + type_called_times_map[node_type] += record.second.called_times / round; + } + std::vector node_types(node_types_set.begin(), + node_types_set.end()); + std::sort(node_types.begin(), node_types.end(), + [&](const std::string &lhs, const std::string &rhs) { + return type_time_map[lhs] > type_time_map[rhs]; + }); + + std::string title = "Stat by node type"; + const std::vector header = { + "Node Type", "Count", "Avg(ms)", "%", "cdf%", "Called times" + }; + + float cdf = 0.0f; + std::vector> data; + for (auto type : node_types) { + const float avg_time = type_time_map[type] / 1000.0f; + const float percentage = type_time_map[type] * 100.0f / total_time; + cdf += percentage; + + std::vector tuple; + tuple.push_back(type); + tuple.push_back(IntToString(type_count_map[type])); + tuple.push_back(FloatToString(avg_time, 3)); + tuple.push_back(FloatToString(percentage, 3)); + tuple.push_back(FloatToString(cdf, 3)); + tuple.push_back(IntToString(type_called_times_map[type])); + data.emplace_back(tuple); + } + return mace::string_util::StringFormatter::Table(title, header, data); +} + +std::string OpStat::Summary() const { + std::stringstream stream; + if (!records_.empty()) { + stream << total_time_.ToString("Summary") << std::endl; + } + + stream << records_.size() << " ops total." << std::endl; + + return stream.str(); +} + +void OpStat::PrintStat() const { + std::stringstream stream; + if (!records_.empty()) { + // op stat by run order + stream << StatByMetric(Metric::RUN_ORDER, 0) << std::endl; + // top-10 op stat by time + stream << StatByMetric(Metric::COMPUTATION_TIME, 10) << std::endl; + // op stat by node type + stream << StatByNodeType() << std::endl; + } + // Print summary + stream << Summary(); + + for (std::string line; std::getline(stream, line);) { + LOG(INFO) << line; + } +} + +} // namespace benchmark +} // namespace mace diff --git a/mace/benchmark/statistics.h b/mace/benchmark/statistics.h new file mode 100644 index 00000000..056df9f4 --- /dev/null +++ b/mace/benchmark/statistics.h @@ -0,0 +1,159 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_BENCHMARK_STATISTICS_H_ +#define MACE_BENCHMARK_STATISTICS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/utils/string_util.h" + +namespace mace { + +class RunMetadata; + +namespace benchmark { + +template +std::string IntToString(const IntType v) { + std::stringstream stream; + stream << v; + return stream.str(); +} + +template +std::string FloatToString(const FloatType v, const int32_t precision) { + std::stringstream stream; + stream << std::fixed << std::setprecision(precision) << v; + return stream.str(); +} +// microseconds +template +class TimeInfo { + public: + TimeInfo():round_(0), first_(0), curr_(0), + min_(std::numeric_limits::max()), max_(0), + sum_(0), square_sum(0) + {} + + const int64_t round() const { + return round_; + } + + const T first() const { + return first_; + } + + const T sum() const { + return sum_; + } + + const double avg() const { + return round_ == 0 ? std::numeric_limits::quiet_NaN() : + sum_ * 1.0f / round_; + } + + const double std_deviation() const { + if (round_ == 0 || min_ == max_) { + return 0; + } + const double avg_value = avg(); + return std::sqrt(square_sum / round_ - avg_value * avg_value); + } + + void UpdateTime(const T time) { + if (round_ == 0) { + first_ = time; + } + + curr_ = time; + min_ = std::min(min_, time); + max_ = std::max(max_, time); + + sum_ += time; + square_sum += static_cast(time) * time; + round_ += 1; + } + + std::string ToString(const std::string &title) const { + std::vector header = { + "round", "first(ms)", "curr(ms)", + "min(ms)", "max(ms)", + "avg(ms)", "std" + }; + std::vector> data(1); + data[0].push_back(IntToString(round_)); + data[0].push_back(FloatToString(first_ / 1000.0, 3)); + data[0].push_back(FloatToString(curr_ / 1000.0, 3)); + data[0].push_back(FloatToString(min_ / 1000.0, 3)); + data[0].push_back(FloatToString(max_ / 1000.0, 3)); + data[0].push_back(FloatToString(avg() / 1000.0, 3)); + data[0].push_back(FloatToString(std_deviation(), 3)); + return mace::string_util::StringFormatter::Table(title, header, data); + } + + private: + T first_; + T curr_; + T min_; + T max_; + T sum_; + int64_t round_; + double square_sum; +}; + +enum Metric { + NAME, + RUN_ORDER, + COMPUTATION_TIME, +}; + +class OpStat{ + public: + void StatMetadata(const RunMetadata &meta_data); + + void PrintStat() const; + + private: + std::string StatByMetric(const Metric metric, + const int top_limit) const; + std::string StatByNodeType() const; + std::string Summary() const; + + private: + struct Record{ + std::string name; + std::string type; + std::vector output_shape; + ConvPoolArgs args; + int64_t order; + TimeInfo start; + TimeInfo rel_end; + int64_t called_times; + }; + + std::map records_; + TimeInfo total_time_; +}; + +} // namespace benchmark +} // namespace mace +#endif // MACE_BENCHMARK_STATISTICS_H_ diff --git a/mace/utils/BUILD b/mace/utils/BUILD index 85e0647d..57bec3d0 100644 --- a/mace/utils/BUILD +++ b/mace/utils/BUILD @@ -14,6 +14,7 @@ cc_library( srcs = [ "command_line_flags.cc", "logging.cc", + "string_util.cc", ], hdrs = [ "command_line_flags.h", diff --git a/mace/utils/string_util.cc b/mace/utils/string_util.cc new file mode 100644 index 00000000..31def7f1 --- /dev/null +++ b/mace/utils/string_util.cc @@ -0,0 +1,84 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/utils/string_util.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace mace { +namespace string_util { + +std::ostream &FormatRow(std::ostream &stream, int width) { + stream << std::right << std::setw(width); + return stream; +} + +std::string StringFormatter::Table( + const std::string &title, + const std::vector &header, + const std::vector> &data) { + if (header.empty()) return ""; + const size_t column_size = header.size(); + const size_t data_size = data.size(); + std::vector max_column_len(header.size(), 0); + for (size_t col_idx = 0; col_idx < column_size; ++col_idx) { + max_column_len[col_idx] = std::max( + max_column_len[col_idx], static_cast(header[col_idx].size())); + for (size_t data_idx = 0; data_idx < data_size; ++data_idx) { + if (col_idx < data[data_idx].size()) { + max_column_len[col_idx] = std::max( + max_column_len[col_idx], + static_cast(data[data_idx][col_idx].size())); + } + } + } + const size_t row_length = + std::accumulate(max_column_len.begin(), max_column_len.end(), + 0, std::plus()) + + 2 * column_size + column_size + 1; + const std::string dash_line(row_length, '-'); + std::stringstream stream; + stream << dash_line << std::endl; + FormatRow(stream, static_cast(row_length / 2 + title.size() / 2)) + << title << std::endl; + stream << dash_line << std::endl; + // format header + stream << "|"; + for (size_t h_idx = 0; h_idx < column_size; ++h_idx) { + stream << " "; + FormatRow(stream, max_column_len[h_idx]) << header[h_idx]; + stream << " |"; + } + stream << std::endl << dash_line << std::endl; + // format data + for (size_t data_idx = 0; data_idx < data_size; ++data_idx) { + stream << "|"; + for (size_t h_idx = 0; h_idx < column_size; ++h_idx) { + stream << " "; + FormatRow(stream, max_column_len[h_idx]) << data[data_idx][h_idx]; + stream << " |"; + } + stream << std::endl << dash_line << std::endl; + } + return stream.str(); +} + +} // namespace string_util +} // namespace mace diff --git a/mace/utils/string_util.h b/mace/utils/string_util.h index 7727d24c..e95bd902 100644 --- a/mace/utils/string_util.h +++ b/mace/utils/string_util.h @@ -37,6 +37,13 @@ inline void MakeStringInternal(std::stringstream &ss, MakeStringInternal(ss, args...); } +class StringFormatter { + public: + static std::string Table(const std::string &title, + const std::vector &header, + const std::vector> &data); +}; + } // namespace string_util template -- GitLab