diff --git a/mace/benchmark/BUILD b/mace/benchmark/BUILD index 9deb79662b463c7fe4936b6439a48df9389fc3cc..c15e74081f0adf50b533b3f59966f7f0f11c5a83 100644 --- a/mace/benchmark/BUILD +++ b/mace/benchmark/BUILD @@ -15,7 +15,6 @@ cc_library( hdrs = ["statistics.h"], copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"], deps = [ - "//mace/kernels", "//mace/utils", ], ) diff --git a/mace/benchmark/statistics.cc b/mace/benchmark/statistics.cc index bc68dd64f5f384f89288c54b4fdb0082214727aa..83e3e5a013a2667c33bbb539a30865d324fe7285 100644 --- a/mace/benchmark/statistics.cc +++ b/mace/benchmark/statistics.cc @@ -16,7 +16,6 @@ #include -#include "mace/kernels/conv_pool_2d_util.h" #include "mace/utils/logging.h" #include "mace/utils/string_util.h" @@ -39,18 +38,18 @@ std::string MetricToString(const Metric metric) { std::string PaddingTypeToString(int padding_type) { std::stringstream stream; - Padding type = static_cast(padding_type); - switch (type) { - case VALID: stream << "VALID"; break; - case SAME: stream << "SAME"; break; - case FULL: stream << "FULL"; break; + switch (padding_type) { + case 0: stream << "VALID"; break; + case 1: stream << "SAME"; break; + case 2: stream << "FULL"; break; default: stream << padding_type; break; } return stream.str(); } -std::string ShapeToString(const std::vector &output_shape) { +std::string ShapeToString( + const std::vector> &output_shape) { if (output_shape.empty()) { return ""; } @@ -58,9 +57,9 @@ std::string ShapeToString(const std::vector &output_shape) { std::stringstream stream; stream << "["; for (size_t i = 0; i < output_shape.size(); ++i) { - size_t dims_size = output_shape[i].dims_size(); + size_t dims_size = output_shape[i].size(); for (size_t j = 0; j < dims_size; ++j) { - stream << output_shape[i].dims(j); + stream << output_shape[i][j]; if (j != dims_size - 1) { stream << ","; } @@ -176,7 +175,7 @@ std::string OpStat::StatByMetric(const Metric metric, } else { tuple.push_back(VectorToString(record.args.paddings)); } - tuple.push_back(VectorToString(record.args.kernels)); + tuple.push_back(VectorToString(record.args.kernels)); tuple.push_back(ShapeToString(record.output_shape)); tuple.push_back(VectorToString(record.args.dilations)); tuple.push_back(record.name); diff --git a/mace/benchmark/statistics.h b/mace/benchmark/statistics.h index b4d69f275aa7b0edd933897b26929d739f6f4c58..50ca901e52df915a485774a3e8b4c13edfa053aa 100644 --- a/mace/benchmark/statistics.h +++ b/mace/benchmark/statistics.h @@ -23,7 +23,7 @@ #include #include -#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/public/mace.h" #include "mace/utils/string_util.h" namespace mace { @@ -142,7 +142,7 @@ class OpStat{ struct Record{ std::string name; std::string type; - std::vector output_shape; + std::vector> output_shape; ConvPoolArgs args; int64_t order; TimeInfo start; diff --git a/mace/core/net.cc b/mace/core/net.cc index 5114d8bc4abebf8eaafa2b20bc3d1b535548a4f3..346ca354eb5eda0c8f4c6690318bacae3be2d425 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -108,9 +108,13 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) { } } + std::vector> output_shapes; + for (auto output_shape : op->debug_def().output_shape()) { + output_shapes.push_back({output_shape.dims().begin(), + output_shape.dims().end()}); + } OperatorStats op_stats = {op->debug_def().name(), op->debug_def().type(), - {op->debug_def().output_shape().begin(), - op->debug_def().output_shape().end()}, + output_shapes, {strides, padding_type, paddings, dilations, kernels}, call_stats}; run_metadata->op_stats.emplace_back(op_stats); diff --git a/mace/public/mace.h b/mace/public/mace.h index bd2e390b96e38f7f6bf8bb16b63c54d3ad1b3e0e..059a79767d5b03fd17a85de3aeb5c96e30f8de3b 100644 --- a/mace/public/mace.h +++ b/mace/public/mace.h @@ -26,7 +26,6 @@ namespace mace { -class OutputShape; class NetDef; enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3 }; @@ -47,7 +46,7 @@ struct ConvPoolArgs { struct OperatorStats { std::string operator_name; std::string type; - std::vector output_shape; + std::vector> output_shape; ConvPoolArgs args; CallStats stats; };