Commit 3cfe0dd6 authored by liuqi

Refactor: remove data_format and add has_data_format flag in Operation

1. Remove the data_format flag, whose meaning was ambiguous.
2. Add has_data_format to distinguish models that carry no data format.
Parent 202ea3a6
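The heart of the change is one pattern applied uniformly across operators: instead of reading an ambiguous `data_format` enum argument, each op now reads a boolean-valued `has_data_format` int argument, and the concrete layout (NCHW on CPU, NHWC on GPU) is implied by the device and model type. A minimal sketch of the before/after constructor pattern (illustrative; `SomeOp` is hypothetical, the argument names are taken from the diff below):

    class SomeOp : public Operation {
     public:
      explicit SomeOp(OpConstructContext *context)
          : Operation(context),
            // Before: data_format_(static_cast<DataFormat>(
            //     Operation::GetOptionalArg<int>("data_format", NHWC)))
            // After: only record whether the model declares a layout at all;
            // the concrete layout follows from the device.
            has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}

     private:
      int has_data_format_;  // non-zero when the model declares a data format
    };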
@@ -90,6 +90,16 @@ DeviceType ParseDeviceType(const std::string &device_str) {
   }
 }

+DataFormat ParseDataFormat(const std::string &data_format_str) {
+  if (data_format_str == "NHWC") {
+    return DataFormat::NHWC;
+  } else if (data_format_str == "NCHW") {
+    return DataFormat::NCHW;
+  } else {
+    return DataFormat::DF_NONE;
+  }
+}
+
 bool RunInference(MaceEngine *engine,
                   const std::map<std::string, mace::MaceTensor> &input_infos,
                   std::map<std::string, mace::MaceTensor> *output_infos,
@@ -168,6 +178,12 @@ DEFINE_string(output_node, "output_node0,output_node1",
               "output nodes, separated by comma");
 DEFINE_string(input_shape, "", "input shape, separated by colon and comma");
 DEFINE_string(output_shape, "", "output shape, separated by colon and comma");
+DEFINE_string(input_data_format,
+              "NHWC",
+              "input data formats, NONE|NHWC|NCHW");
+DEFINE_string(output_data_format,
+              "NHWC",
+              "output data formats, NONE|NHWC|NCHW");
 DEFINE_string(input_file, "", "input file name");
 DEFINE_int32(max_num_runs, 100, "max number of runs");
 DEFINE_double(max_seconds, 10.0, "max number of seconds to run");
@@ -233,6 +249,19 @@ int Main(int argc, char **argv) {
     ParseShape(output_shapes[i], &output_shape_vec[i]);
   }
+
+  std::vector<std::string> raw_input_data_formats =
+      str_util::Split(FLAGS_input_data_format, ',');
+  std::vector<std::string> raw_output_data_formats =
+      str_util::Split(FLAGS_output_data_format, ',');
+  std::vector<DataFormat> input_data_formats(input_count);
+  std::vector<DataFormat> output_data_formats(output_count);
+  for (size_t i = 0; i < input_count; ++i) {
+    input_data_formats[i] = ParseDataFormat(raw_input_data_formats[i]);
+  }
+  for (size_t i = 0; i < output_count; ++i) {
+    output_data_formats[i] = ParseDataFormat(raw_output_data_formats[i]);
+  }

   mace::DeviceType device_type = ParseDeviceType(FLAGS_device);
   // configuration
@@ -333,7 +362,8 @@ int Main(int argc, char **argv) {
       LOG(INFO) << "Open input file failed";
       return -1;
     }
-    inputs[input_names[i]] = mace::MaceTensor(input_shape_vec[i], buffer_in);
+    inputs[input_names[i]] = mace::MaceTensor(input_shape_vec[i], buffer_in,
+                                              input_data_formats[i]);
   }

   for (size_t i = 0; i < output_count; ++i) {
@@ -344,7 +374,8 @@ int Main(int argc, char **argv) {
     auto buffer_out = std::shared_ptr<float>(new float[output_size],
                                              std::default_delete<float[]>());
     outputs[output_names[i]] = mace::MaceTensor(output_shape_vec[i],
-                                                buffer_out);
+                                                buffer_out,
+                                                output_data_formats[i]);
   }

   int64_t warmup_time_us = 0;
......
@@ -115,6 +115,8 @@ void MemoryOptimizer::Optimize(
                              op_def->output_type_size());
     DataType dt;
+    bool has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+        *op_def, "has_data_format", 0) != 0;
     int output_size = op_def->output_size();
     for (int i = 0; i < output_size; ++i) {
       if (i < op_def->output_type_size()) {
@@ -134,7 +136,7 @@ void MemoryOptimizer::Optimize(
     MemoryBlock best_mem_block;
     if (IsMemoryReuseOp(op_def->type())) {
       if (tensor_mem_map_.count(op_def->input(0)) == 1) {
-        best_mem_id = tensor_mem_map_[op_def->input(0)].first;
+        best_mem_id = tensor_mem_map_.at(op_def->input(0)).mem_id;
       }
     } else {
       auto shape = std::vector<int64_t>(
@@ -204,7 +206,8 @@ void MemoryOptimizer::Optimize(
       } else {
         mem_ref_count_[best_mem_id] = 1;
       }
-      tensor_mem_map_[op_def->output(i)] = std::make_pair(best_mem_id, dt);
+      tensor_mem_map_.emplace(op_def->output(i), TensorMemInfo(best_mem_id,
+          dt, has_data_format));
     }
   }
@@ -216,7 +219,7 @@ void MemoryOptimizer::Optimize(
       tensor_ref_count_[input_name] -= 1;
       if (tensor_ref_count_.at(input_name) == 0 &&
           tensor_mem_map_.count(input_name) == 1) {
-        int mem_id = tensor_mem_map_.at(input_name).first;
+        int mem_id = tensor_mem_map_.at(input_name).mem_id;
         mem_ref_count_[mem_id] -= 1;
         if (mem_ref_count_.at(mem_id) == 0) {
           idle_blocks_.insert(mem_id);
@@ -236,7 +239,7 @@ const std::vector<MemoryBlock>& MemoryOptimizer::mem_blocks() const {
   return mem_blocks_;
 }

-const std::unordered_map<std::string, std::pair<int, DataType>>&
+const std::unordered_map<std::string, MemoryOptimizer::TensorMemInfo>&
 MemoryOptimizer::tensor_mem_map() const {
   return tensor_mem_map_;
 }
......
@@ -77,6 +77,17 @@ class MemoryBlock {
 };

 class MemoryOptimizer {
+ public:
+  struct TensorMemInfo {
+    int mem_id;
+    DataType data_type;
+    bool has_data_format;
+
+    TensorMemInfo(int mem_id, DataType data_type, bool has_data_format) :
+        mem_id(mem_id), data_type(data_type), has_data_format(has_data_format)
+    {}
+  };
+
  public:
   static bool IsMemoryReuseOp(const std::string &op_type);
   void UpdateTensorRef(const std::string &tensor_name);
@@ -86,8 +97,7 @@ class MemoryOptimizer {
   const std::vector<MemoryBlock> &mem_blocks() const;

-  const std::unordered_map<std::string,
-      std::pair<int, DataType>> &tensor_mem_map() const;
+  const std::unordered_map<std::string, TensorMemInfo> &tensor_mem_map() const;

   std::string DebugInfo() const;
@@ -101,7 +111,7 @@ class MemoryOptimizer {
   std::vector<MemoryBlock> mem_blocks_;
   // tensor name : <mem_id, data_type>
   // Buffer Memory do not different data type, so store the data type.
-  std::unordered_map<std::string, std::pair<int, DataType>> tensor_mem_map_;
+  std::unordered_map<std::string, TensorMemInfo> tensor_mem_map_;
   std::unordered_map<int, int> mem_ref_count_;
   std::set<int> idle_blocks_;
 };
......
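For reference, a sketch of how the new struct is used (the type and field names come from the hunks above; the snippet itself is illustrative):

    // Named fields replace the old std::pair<int, DataType>, and the new
    // has_data_format member travels with each tensor's memory record.
    std::unordered_map<std::string, MemoryOptimizer::TensorMemInfo> mem_map;
    mem_map.emplace("conv1:0", MemoryOptimizer::TensorMemInfo(
        /*mem_id=*/3, DataType::DT_FLOAT, /*has_data_format=*/true));
    int mem_id = mem_map.at("conv1:0").mem_id;  // previously .first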
@@ -70,7 +70,7 @@ std::unique_ptr<Operation> SerialNet::CreateOperation(
     const OpRegistryBase *op_registry,
     OpConstructContext *construct_context,
     std::shared_ptr<OperatorDef> op_def,
-    DataFormat data_format_flag,
+    bool has_data_format,
     bool is_quantize_model) {
   // Create the Operation
   DeviceType target_device_type = target_device_->device_type();
@@ -100,8 +100,7 @@ std::unique_ptr<Operation> SerialNet::CreateOperation(
   if (!is_quantize_model && device_type == DeviceType::CPU &&
       op_def->output_shape_size() == op_def->output_size()) {
     for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) {
-      if (data_format_flag == NHWC &&
-          op_def->output_shape(out_idx).dims_size() == 4) {
+      if (has_data_format && op_def->output_shape(out_idx).dims_size() == 4) {
         // NHWC -> NCHW
         std::vector<index_t> output_shape =
             TransposeShape<index_t, index_t>(
@@ -160,7 +159,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
                                 tensor.dims().end()));
   }

-  DataFormat data_format_flag = NHWC;
+  bool has_data_format = false;
   if (target_device_->device_type() == DeviceType::CPU) {
     target_mem_type = MemoryType::CPU_BUFFER;
     for (auto &input_info : net_def->input_info()) {
@@ -170,15 +169,15 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
       // update tensor shape map
       tensor_shape_map[input_info.name()] = input_shape;
       // Only could be NONE or NHWC
-      auto input_data_format = static_cast<DataFormat>(
+      DataFormat input_data_format = static_cast<DataFormat>(
           input_info.data_format());
-      if (!is_quantize_model && input_data_format == NHWC &&
+      has_data_format = has_data_format ||
+          (input_data_format != DataFormat::DF_NONE);
+      if (!is_quantize_model && input_data_format == DataFormat::NHWC &&
           input_info.dims_size() == 4) {
         // NHWC -> NCHW
         input_shape =
             TransposeShape<index_t, index_t>(input_shape, {0, 3, 1, 2});
-      } else if (input_data_format == DataFormat::DF_NONE) {
-        data_format_flag = DataFormat::DF_NONE;
       }
       output_map.emplace(input_info.name(), InternalOutputInfo(
           target_mem_type, DataType::DT_FLOAT, input_shape, -1));
@@ -189,11 +188,8 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
   else {  // GPU  NOLINT[readability/braces]
     target_mem_type = MemoryType::GPU_BUFFER;
     for (auto &input_info : net_def->input_info()) {
-      auto input_data_format = static_cast<DataFormat>(
-          input_info.data_format());
-      if (input_data_format == DataFormat::DF_NONE) {
-        data_format_flag = DataFormat::DF_NONE;
-      }
+      has_data_format = static_cast<DataFormat>(
+          input_info.data_format()) == NHWC;
       std::vector<index_t> input_shape =
           std::vector<index_t>(input_info.dims().begin(),
                                input_info.dims().end());
@@ -212,7 +208,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
     auto op = CreateOperation(op_registry,
                               &construct_context,
                               op_def,
-                              data_format_flag,
+                              has_data_format,
                               is_quantize_model);
 #ifdef MACE_ENABLE_OPENCL
     // Add input transform operation if necessary
@@ -259,13 +255,13 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
           }
           auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
               input_name, input_shape, t_input_name,
-              wanted_in_dt, wanted_in_mem_type, data_format_flag);
+              wanted_in_dt, wanted_in_mem_type, has_data_format);
           OpConstructContext t_construct_context(ws_);
           auto transform_op = CreateOperation(
               op_registry,
               &t_construct_context,
               transform_op_def,
-              data_format_flag);
+              has_data_format);
           operators_.emplace_back(std::move(transform_op));
           transformed_set.insert(t_input_name);
           output_mem_map[t_input_name] = wanted_in_mem_type;
@@ -340,7 +336,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
             output_mem_map[output_info.name()] = target_mem_type;
           }
         }
-        auto output_data_format =
+        bool output_has_data_format =
             static_cast<DataFormat>(output_info.data_format());
         auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
             t_output_name,
@@ -348,12 +344,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
             output_info.name(),
             output_info.data_type(),
             target_mem_type,
-            data_format_flag);
+            output_has_data_format);
         auto transform_op = CreateOperation(
             op_registry,
             &construct_context,
             transform_op_def,
-            output_data_format);
+            output_has_data_format);
         operators_.emplace_back(std::move(transform_op));
         // where to do graph reference count.
         mem_optimizer->UpdateTensorRef(transform_op_def.get());
......
@@ -59,7 +59,7 @@ class SerialNet : public NetBase {
       const OpRegistryBase *op_registry,
       OpConstructContext *construct_context,
       std::shared_ptr<OperatorDef> op_def,
-      DataFormat input_format,
+      bool has_data_format,
       bool is_quantize_model = false);

  protected:
......
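Condensing the serial_net.cc hunks above: the flag is now derived once per net rather than toggled per input. A sketch of that behavior (not a verbatim excerpt; names as in the diff):

    bool has_data_format = false;
    if (device_type == DeviceType::CPU) {
      // CPU: true if any model input declares a data format
      for (auto &input_info : net_def->input_info()) {
        has_data_format = has_data_format ||
            static_cast<DataFormat>(input_info.data_format()) !=
                DataFormat::DF_NONE;
      }
    } else {
      // GPU: true when the input declares NHWC
      // (note the loop keeps the last input's value)
      for (auto &input_info : net_def->input_info()) {
        has_data_format =
            static_cast<DataFormat>(input_info.data_format()) == NHWC;
      }
    }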
@@ -152,7 +152,7 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
     const std::string &output_name,
     const mace::DataType dt,
     const mace::MemoryType mem_type,
-    const DataFormat data_format) {
+    bool has_data_format) {
   std::unique_ptr<OperatorDef> op(new OperatorDef);
   std::string op_name = "mace_node_" + output_name;
   op->set_name(op_name);
@@ -169,8 +169,8 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
   arg->set_name("T");
   arg->set_i(static_cast<int32_t>(dt));
   arg = op->add_arg();
-  arg->set_name("data_format");
-  arg->set_i(data_format);
+  arg->set_name("has_data_format");
+  arg->set_i(has_data_format);
   if (!input_shape.empty()) {
     OutputShape *shape = op->add_output_shape();
     for (auto value : input_shape) {
......
@@ -49,7 +49,7 @@ class OpenCLUtil {
       const std::string &output_name,
       const mace::DataType dt,
       const MemoryType mem_type,
-      const DataFormat data_format);
+      bool has_data_format);
 };

 }  // namespace mace
......
@@ -264,31 +264,35 @@ MaceStatus Workspace::PreallocateOutputTensor(
   bool is_quantize_model = IsQuantizedModel(net_def);
   for (auto &tensor_mem : mem_optimizer->tensor_mem_map()) {
     std::unique_ptr<Tensor> tensor
-        (new Tensor(preallocated_allocator_.GetBuffer(tensor_mem.second.first),
-                    tensor_mem.second.second,
+        (new Tensor(preallocated_allocator_.GetBuffer(tensor_mem.second.mem_id),
+                    tensor_mem.second.data_type,
                     false, tensor_mem.first));
-    if (mem_blocks[tensor_mem.second.first].mem_type()
-        == MemoryType::GPU_IMAGE) {
-      VLOG(1) << "Tensor: " << tensor_mem.first
-              << " Mem: " << tensor_mem.second.first
-              << " Data type: " << tensor->dtype()
-              << " Image shape: "
-              << tensor->UnderlyingBuffer()->shape()[0]
-              << ", "
-              << tensor->UnderlyingBuffer()->shape()[1];
-      tensor->set_data_format(DataFormat::NHWC);
-    } else {
-      VLOG(1) << "Tensor: " << tensor_mem.first
-              << " Mem: " << tensor_mem.second.first
-              << " Data type: " << tensor->dtype()
-              << ", Buffer size: " << tensor->UnderlyingBuffer()->size();
-      if (mem_blocks[tensor_mem.second.first].mem_type()
-          == MemoryType::GPU_BUFFER ||
-          is_quantize_model) {
+    if (tensor_mem.second.has_data_format) {
+      if (mem_blocks[tensor_mem.second.mem_id].mem_type()
+          == MemoryType::GPU_IMAGE) {
+        VLOG(1) << "Tensor: " << tensor_mem.first
+                << " Mem: " << tensor_mem.second.mem_id
+                << " Data type: " << tensor->dtype()
+                << " Image shape: "
+                << tensor->UnderlyingBuffer()->shape()[0]
+                << ", "
+                << tensor->UnderlyingBuffer()->shape()[1];
         tensor->set_data_format(DataFormat::NHWC);
       } else {
-        tensor->set_data_format(DataFormat::NCHW);
+        VLOG(1) << "Tensor: " << tensor_mem.first
+                << " Mem: " << tensor_mem.second.mem_id
+                << " Data type: " << tensor->dtype()
+                << ", Buffer size: " << tensor->UnderlyingBuffer()->size();
+        if (mem_blocks[tensor_mem.second.mem_id].mem_type()
+            == MemoryType::GPU_BUFFER ||
+            is_quantize_model) {
+          tensor->set_data_format(DataFormat::NHWC);
+        } else {
+          tensor->set_data_format(DataFormat::NCHW);
+        }
       }
+    } else {
+      tensor->set_data_format(DataFormat::DF_NONE);
     }
     tensor_map_[tensor_mem.first] = std::move(tensor);
   }
......
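The branching above reduces to a small decision table. A standalone sketch of how each preallocated tensor's format is now chosen (hypothetical helper name; logic condensed from the hunk above):

    DataFormat ChooseDataFormat(bool has_data_format, MemoryType mem_type,
                                bool is_quantize_model) {
      if (!has_data_format) return DataFormat::DF_NONE;  // formatless models
      if (mem_type == MemoryType::GPU_IMAGE ||
          mem_type == MemoryType::GPU_BUFFER ||
          is_quantize_model) {
        return DataFormat::NHWC;
      }
      return DataFormat::NCHW;  // CPU float models keep NCHW buffers
    }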
@@ -170,6 +170,15 @@ DeviceType ParseDeviceType(const std::string &device_str) {
   }
 }

+DataFormat ParseDataFormat(const std::string &data_format_str) {
+  if (data_format_str == "NHWC") {
+    return DataFormat::NHWC;
+  } else if (data_format_str == "NCHW") {
+    return DataFormat::NCHW;
+  } else {
+    return DataFormat::DF_NONE;
+  }
+}
+
 DEFINE_string(model_name,
               "",
@@ -186,6 +195,12 @@ DEFINE_string(output_node,
 DEFINE_string(output_shape,
               "1,224,224,2:1,1,1,10",
               "output shapes, separated by colon and comma");
+DEFINE_string(input_data_format,
+              "NHWC",
+              "input data formats, NONE|NHWC|NCHW");
+DEFINE_string(output_data_format,
+              "NHWC",
+              "output data formats, NONE|NHWC|NCHW");
 DEFINE_string(input_file,
               "",
               "input file name | input file prefix for multiple inputs.");
@@ -222,8 +237,10 @@ DEFINE_int32(cpu_affinity_policy, 1,
 bool RunModel(const std::vector<std::string> &input_names,
               const std::vector<std::vector<int64_t>> &input_shapes,
+              const std::vector<DataFormat> &input_data_formats,
               const std::vector<std::string> &output_names,
-              const std::vector<std::vector<int64_t>> &output_shapes) {
+              const std::vector<std::vector<int64_t>> &output_shapes,
+              const std::vector<DataFormat> &output_data_formats) {
   // load model
   DeviceType device_type = ParseDeviceType(FLAGS_device);
   // configuration
@@ -324,7 +341,8 @@ bool RunModel(const std::vector<std::string> &input_names,
     inputs_size[input_names[i]] = input_size;
     auto buffer_in = std::shared_ptr<float>(new float[input_size],
                                             std::default_delete<float[]>());
-    inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in);
+    inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in,
+                                              input_data_formats[i]);
   }
   for (size_t i = 0; i < output_count; ++i) {
@@ -333,7 +351,8 @@ bool RunModel(const std::vector<std::string> &input_names,
                        std::multiplies<int64_t>());
     auto buffer_out = std::shared_ptr<float>(new float[output_size],
                                              std::default_delete<float[]>());
-    outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out);
+    outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out,
+                                                output_data_formats[i]);
   }

   if (!FLAGS_input_dir.empty()) {
@@ -485,11 +504,25 @@ int Main(int argc, char **argv) {
     ParseShape(output_shapes[i], &output_shape_vec[i]);
   }
+
+  std::vector<std::string> raw_input_data_formats =
+      str_util::Split(FLAGS_input_data_format, ',');
+  std::vector<std::string> raw_output_data_formats =
+      str_util::Split(FLAGS_output_data_format, ',');
+  std::vector<DataFormat> input_data_formats(input_count);
+  std::vector<DataFormat> output_data_formats(output_count);
+  for (size_t i = 0; i < input_count; ++i) {
+    input_data_formats[i] = ParseDataFormat(raw_input_data_formats[i]);
+  }
+  for (size_t i = 0; i < output_count; ++i) {
+    output_data_formats[i] = ParseDataFormat(raw_output_data_formats[i]);
+  }
+
   bool ret = false;
   for (int i = 0; i < FLAGS_restart_round; ++i) {
     std::cout << "restart round " << i << std::endl;
     ret =
-        RunModel(input_names, input_shape_vec, output_names, output_shape_vec);
+        RunModel(input_names, input_shape_vec, input_data_formats,
+                 output_names, output_shape_vec, output_data_formats);
   }
   if (ret) {
     return 0;
......
@@ -143,6 +143,7 @@ void BMNet::SetUp() {
   // Add input and output information
   for (size_t i = 0; i < input_names_.size(); ++i) {
     InputInfo *info = net_.add_input_info();
+    info->set_data_format(DataFormat::NHWC);
     info->set_name(input_names_[i]);
     for (auto d : input_shapes_[i]) {
       info->add_dims(static_cast<int>(d));
@@ -243,8 +244,8 @@ void BMNet::AddConv(const std::string &conv_type,
   op_def->add_output(output_name);
   AddIntsArg(op_def, "strides", strides);
   AddIntArg(op_def, "padding", padding_type);
+  AddIntArg(op_def, "has_data_format", 1);
   AddIntArg(op_def, "T", DT_HALF);
-  AddIntArg(op_def, "data_format", 1);
   if (has_relu6) {
     AddStringArg(op_def, "activation", "RELUX");
     AddFloatArg(op_def, "max_limit", 6);
@@ -270,7 +271,7 @@ void BMNet::AddEltwise(const std::string &op_name,
   op_def->add_output(output);
   AddIntArg(op_def, "type", type);
   AddIntArg(op_def, "T", DT_HALF);
-  AddIntArg(op_def, "data_format", 1);
+  AddIntArg(op_def, "has_data_format", 1);
   OutputShape *shape = op_def->add_output_shape();
   for (auto dim : output_shape) {
     shape->add_dims(dim);
......
@@ -470,6 +470,9 @@ MaceStatus MaceEngine::Impl::Init(
       shape[i] = input_info_map_[input_name].dims(i);
     }
     input_tensor->Resize(shape);
+    // Set to the default data format
+    input_tensor->set_data_format(static_cast<DataFormat>(
+        input_info_map_[input_name].data_format()));
   }
   for (auto output_name : output_nodes) {
     if (output_info_map_.find(output_name) == output_info_map_.end()) {
@@ -477,7 +480,9 @@ MaceStatus MaceEngine::Impl::Init(
                  << "' does not belong to model's outputs "
                  << MakeString(MapKeys(output_info_map_));
     }
+#ifdef MACE_ENABLE_HEXAGON
     ws_->CreateTensor(output_name, device_->allocator(), DT_FLOAT);
+#endif
   }
 #ifdef MACE_ENABLE_HEXAGON
   if (device_type_ == HEXAGON) {
@@ -559,47 +564,51 @@ MaceEngine::Impl::~Impl() {
 MaceStatus MaceEngine::Impl::TransposeInput(
     const std::pair<const std::string, MaceTensor> &input,
     Tensor *input_tensor) {
-  if (device_->device_type() == DeviceType::CPU &&
-      input.second.shape().size() == 4 &&
-      input.second.data_format() == NHWC &&
-      !is_quantized_model_) {
-    VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW";
-    input_tensor->set_data_format(DataFormat::NCHW);
-    std::vector<int> dst_dims = {0, 3, 1, 2};
-    std::vector<index_t> output_shape =
-        TransposeShape<int64_t, index_t>(input.second.shape(), dst_dims);
-    MACE_RETURN_IF_ERROR(input_tensor->Resize(output_shape));
-    Tensor::MappingGuard input_guard(input_tensor);
-    float *input_data = input_tensor->mutable_data<float>();
-    return ops::Transpose(input.second.data().get(),
-                          input.second.shape(),
-                          dst_dims,
-                          input_data);
-  } else if (
-      (is_quantized_model_ || device_->device_type() == DeviceType::GPU) &&
-      input.second.shape().size() == 4 &&
-      input.second.data_format() == DataFormat::NCHW) {
-    VLOG(1) << "Transform input " << input.first << " from NCHW to NHWC";
-    std::vector<int> dst_dims = {0, 2, 3, 1};
-    input_tensor->set_data_format(DataFormat::NHWC);
-    std::vector<index_t> output_shape =
-        TransposeShape<int64_t, index_t>(input.second.shape(), dst_dims);
-    MACE_RETURN_IF_ERROR(input_tensor->Resize(output_shape));
-    Tensor::MappingGuard input_guard(input_tensor);
-    float *input_data = input_tensor->mutable_data<float>();
-    return ops::Transpose(input.second.data().get(),
-                          input.second.shape(),
-                          dst_dims,
-                          input_data);
-  } else {
-    input_tensor->set_data_format(input.second.data_format());
-    MACE_RETURN_IF_ERROR(input_tensor->Resize(input.second.shape()));
-    Tensor::MappingGuard input_guard(input_tensor);
-    float *input_data = input_tensor->mutable_data<float>();
-    memcpy(input_data, input.second.data().get(),
-           input_tensor->size() * sizeof(float));
-    return MaceStatus::MACE_SUCCESS;
-  }
+  bool has_data_format = input_tensor->data_format() != DataFormat::DF_NONE;
+  DataFormat data_format = DataFormat::DF_NONE;
+  if (has_data_format) {
+    if (device_->device_type() == DeviceType::CPU &&
+        input.second.shape().size() == 4 &&
+        input.second.data_format() == NHWC &&
+        !is_quantized_model_) {
+      VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW";
+      input_tensor->set_data_format(DataFormat::NCHW);
+      std::vector<int> dst_dims = {0, 3, 1, 2};
+      std::vector<index_t> output_shape =
+          TransposeShape<int64_t, index_t>(input.second.shape(), dst_dims);
+      MACE_RETURN_IF_ERROR(input_tensor->Resize(output_shape));
+      Tensor::MappingGuard input_guard(input_tensor);
+      float *input_data = input_tensor->mutable_data<float>();
+      return ops::Transpose(input.second.data().get(),
+                            input.second.shape(),
+                            dst_dims,
+                            input_data);
+    } else if (
+        (is_quantized_model_ || device_->device_type() == DeviceType::GPU) &&
+        input.second.shape().size() == 4 &&
+        input.second.data_format() == DataFormat::NCHW) {
+      VLOG(1) << "Transform input " << input.first << " from NCHW to NHWC";
+      std::vector<int> dst_dims = {0, 2, 3, 1};
+      input_tensor->set_data_format(DataFormat::NHWC);
+      std::vector<index_t> output_shape =
+          TransposeShape<int64_t, index_t>(input.second.shape(), dst_dims);
+      MACE_RETURN_IF_ERROR(input_tensor->Resize(output_shape));
+      Tensor::MappingGuard input_guard(input_tensor);
+      float *input_data = input_tensor->mutable_data<float>();
+      return ops::Transpose(input.second.data().get(),
+                            input.second.shape(),
+                            dst_dims,
+                            input_data);
+    }
+    data_format = input.second.data_format();
+  }
+  input_tensor->set_data_format(data_format);
+  MACE_RETURN_IF_ERROR(input_tensor->Resize(input.second.shape()));
+  Tensor::MappingGuard input_guard(input_tensor);
+  float *input_data = input_tensor->mutable_data<float>();
+  memcpy(input_data, input.second.data().get(),
+         input_tensor->size() * sizeof(float));
+  return MaceStatus::MACE_SUCCESS;
 }

 MaceStatus MaceEngine::Impl::TransposeOutput(
@@ -607,38 +616,28 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
     std::pair<const std::string, mace::MaceTensor> *output) {
   // save output
   if (output_tensor != nullptr && output->second.data() != nullptr) {
-    if (device_->device_type() == DeviceType::CPU &&
-        output->second.shape().size() == 4 &&
-        output->second.data_format() != output_tensor->data_format()) {
-      MACE_CHECK(output_tensor->data_format() == NCHW);
-      VLOG(1) << "Transform output " << output->first << " from NCHW to NHWC";
-      std::vector<int> dst_dims = {0, 2, 3, 1};
-      std::vector<index_t> shape =
-          TransposeShape<index_t, index_t>(output_tensor->shape(),
-                                           dst_dims);
-      int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
-                                            std::multiplies<int64_t>());
-      MACE_CHECK(output_size <= output->second.impl_->buffer_size)
-          << "Output size exceeds buffer size: shape"
-          << MakeString<int64_t>(shape) << " vs buffer size "
-          << output->second.impl_->buffer_size;
-      output->second.impl_->shape = shape;
-      Tensor::MappingGuard output_guard(output_tensor);
-      const float *output_data = output_tensor->data<float>();
-      return ops::Transpose(output_data,
-                            output_tensor->shape(),
-                            dst_dims,
-                            output->second.data().get());
-    } else if (device_->device_type() == DeviceType::GPU &&
-               output->second.shape().size() == 4 &&
-               output->second.data_format() != output_tensor->data_format()) {
+    if (output_tensor->data_format() != DataFormat::DF_NONE &&
+        output->second.data_format() != DataFormat::DF_NONE &&
+        output->second.shape().size() == 4 &&
+        output->second.data_format() != output_tensor->data_format()) {
       VLOG(1) << "Transform output " << output->first << " from "
               << output_tensor->data_format() << " to "
               << output->second.data_format();
-      std::vector<int> dst_dims = {0, 3, 1, 2};
-      if (output_tensor->data_format() == NCHW) {
+      std::vector<int> dst_dims;
+      if (output_tensor->data_format() == NCHW &&
+          output->second.data_format() == NHWC) {
         dst_dims = {0, 2, 3, 1};
+      } else if (output_tensor->data_format() == NHWC &&
+                 output->second.data_format() == NCHW) {
+        dst_dims = {0, 3, 1, 2};
+      } else {
+        LOG(FATAL) << "Not supported output data format: "
+                   << output->second.data_format() << " vs "
+                   << output_tensor->data_format();
       }
+      VLOG(1) << "Transform output " << output->first << " from "
+              << output_tensor->data_format() << " to "
+              << output->second.data_format();
       std::vector<index_t> shape =
           TransposeShape<index_t, index_t>(output_tensor->shape(),
                                            dst_dims);
......
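For clarity, the two 4-D permutations used by TransposeInput/TransposeOutput above, with a worked example (a sketch; output dim i takes input dim dst_dims[i]):

    std::vector<index_t> nhwc_shape = {1, 224, 224, 3};
    // NHWC -> NCHW uses dst_dims = {0, 3, 1, 2}:
    std::vector<index_t> nchw_shape =
        TransposeShape<index_t, index_t>(nhwc_shape, {0, 3, 1, 2});
    // nchw_shape == {1, 3, 224, 224}; the reverse direction,
    // NCHW -> NHWC, uses dst_dims = {0, 2, 3, 1}.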
@@ -35,8 +35,8 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
  public:
   explicit BiasAddOp(OpConstructContext *context)
       : Operation(context),
-        data_format_(static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-            "data_format", NHWC))) {}
+        has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0))
+        {}

   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
@@ -57,7 +57,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
     const float *bias_ptr = bias->data<float>();
     float *output_ptr = output->mutable_data<float>();

-    if (input->dim_size() == 4 && data_format_ == NCHW) {
+    if (input->dim_size() == 4 && has_data_format_) {
       const index_t batch = input->dim(0);
       const index_t channels = input->dim(1);
       const index_t height_width = input->dim(2) * input->dim(3);
@@ -90,7 +90,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
   }

  private:
-  DataFormat data_format_;
+  int has_data_format_;
 };

 #ifdef MACE_ENABLE_OPENCL
@@ -99,8 +99,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation {
  public:
   explicit BiasAddOp(OpConstructContext *context)
       : Operation(context),
-        data_format_(static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-            "data_format", NHWC))) {
+        has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 1)) {
     MemoryType mem_type;
     if (context->device()->gpu_runtime()->UseImageMemory()) {
       mem_type = MemoryType::GPU_IMAGE;
@@ -121,13 +120,13 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation {
     Tensor *output = this->Output(0);
     MACE_RETURN_IF_ERROR(output->ResizeLike(input));

-    MACE_CHECK(input->dim_size() == 4 && data_format_ == NHWC,
+    MACE_CHECK(input->dim_size() == 4 && has_data_format_,
                "gpu only support biasadd for 4-dimensional NHWC format tensor");
     return kernel_->Compute(context, input, bias, output);
   }

  private:
-  DataFormat data_format_;
+  int has_data_format_;
   std::unique_ptr<OpenCLBiasAddKernel> kernel_;
 };
 #endif  // MACE_ENABLE_OPENCL
......
@@ -42,7 +42,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
   OpDefBuilder("BiasAdd", "BiasAddBM")
       .Input("Input")
       .Input("Bias")
-      .AddIntArg("data_format", data_format)
+      .AddIntArg("has_data_format", 1)
       .Output("Output")
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
......
@@ -36,7 +36,7 @@ void BiasAddSimple() {
   OpDefBuilder("BiasAdd", "BiasAddTest")
       .Input("InputNCHW")
       .Input("Bias")
-      .AddIntArg("data_format", NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("OutputNCHW")
       .Finalize(net.NewOperatorDef());
   // Run
@@ -90,7 +90,7 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
   OpDefBuilder("BiasAdd", "BiasAddTest")
       .Input("InputNCHW")
       .Input("Bias")
-      .AddIntArg("data_format", NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("OutputNCHW")
       .Finalize(net.NewOperatorDef());
@@ -139,7 +139,7 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
   OpDefBuilder("BiasAdd", "BiasAddTest")
       .Input("InputNCHW")
       .Input("Bias")
-      .AddIntArg("data_format", NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("OutputNCHW")
       .Finalize(net.NewOperatorDef());
......
@@ -39,14 +39,14 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation {
     auto type =
         static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>(
             "buffer_type", static_cast<int>(CONV2D_FILTER)));
-    auto data_format = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-        "data_format", DataFormat::DF_NONE));
+    bool has_data_format = Operation::GetOptionalArg<int>("has_data_format", 0)
+        != 0;
     MemoryType in_mem_type = context->workspace()->GetTensor(
         operator_def_->input(0))->memory_type();
     return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform(
         context, input, type, out_mem_type_, wino_blk_size_,
-        data_format, output);
+        has_data_format, output);
   }

  private:
......
@@ -60,9 +60,9 @@ class ConcatOp<DeviceType::CPU, T> : public ConcatOpBase {
     MACE_UNUSED(context);
     if (!checked_) {
       Validate();
-      auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-          "data_format", DataFormat::DF_NONE));
-      if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
+      auto has_df = Operation::GetOptionalArg<int>(
+          "has_data_format", 0);
+      if (has_df && this->Input(0)->dim_size() == 4) {
         if (axis_ == 3) axis_ = 1;
         else if (axis_ == 2) axis_ = 3;
         else if (axis_ == 1) axis_ = 2;
@@ -251,9 +251,12 @@ void RegisterConcat(OpRegistryBase *op_registry) {
                 if (op->output_shape(0).dims_size() != 4) {
                   return { DeviceType::CPU };
                 } else {
+                  int has_data_format =
+                      ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+                          *op, "has_data_format", 0);
                   int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
                       *op, "axis", 3);
-                  if (axis != 3) {
+                  if (!has_data_format || axis != 3) {
                     return { DeviceType::CPU };
                   }
                   bool divisible_four = true;
......
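Condensing the registration lambda above: Concat is only placed on GPU when its 4-D output carries a data format, it concatenates along the channel axis, and every input's channel count is divisible by 4. A sketch of the predicate (hypothetical helper name; the divisibility loop is elided, see the hunk above):

    bool PlaceConcatOnGpu(const OperatorDef &op) {
      if (op.output_shape(0).dims_size() != 4) return false;
      int has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          op, "has_data_format", 0);
      int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          op, "axis", 3);
      if (!has_data_format || axis != 3) return false;
      // ... also require each input's channel dim to be divisible by 4 ...
      return true;
    }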
@@ -91,6 +91,7 @@ void OpenCLConcatHelper(int iters,
       .Input("Input0")
       .Input("Input1")
       .AddIntArg("axis", concat_dim)
+      .AddIntArg("has_data_format", 1)
       .Output("Output")
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
......
@@ -100,11 +100,12 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
   }
 }

-TEST_F(ConcatOpTest, CPURandom) {
+namespace {
+void CPURandomTest(int input_dim, int has_data_format) {
   static unsigned int seed = time(NULL);
-  int dim = 5;
+  int dim = input_dim;
   int num_inputs = 2 + rand_r(&seed) % 10;
-  int axis = 1;
+  int axis = 3;
   // Construct graph
   OpsTestNet net;
   auto builder = OpDefBuilder("Concat", "ConcatTest");
@@ -112,9 +113,13 @@ TEST_F(ConcatOpTest, CPURandom) {
     builder = builder.Input(MakeString("Input", i));
   }
   builder.AddIntArg("axis", axis)
+      .AddIntArg("has_data_format", has_data_format)
       .Output("Output")
       .Finalize(net.NewOperatorDef());
+  if (has_data_format) {
+    axis = 1;
+  }

   std::vector<index_t> shape_data;
   GenerateRandomIntTypeData<index_t>({dim}, &shape_data, 1, dim);
   std::vector<std::vector<index_t>> input_shapes(num_inputs, shape_data);
@@ -152,6 +157,13 @@ TEST_F(ConcatOpTest, CPURandom) {
     }
   }
 }
+}  // namespace
+
+TEST_F(ConcatOpTest, CPURandom) {
+  CPURandomTest(5, 0);
+  CPURandomTest(4, 0);
+  CPURandomTest(4, 1);
+}

 TEST_F(ConcatOpTest, QuantizedCPURandom) {
   static unsigned int seed = time(NULL);
@@ -186,7 +198,7 @@ TEST_F(ConcatOpTest, QuantizedCPURandom) {
     builder = builder.Input(MakeString("Input", i));
   }
   builder.AddIntArg("axis", axis_arg)
-      .AddIntArg("data_format", DataFormat::NHWC)
+      .AddIntArg("has_data_format", 1)
       .Output("Output")
       .Finalize(net.NewOperatorDef());
@@ -248,7 +260,7 @@ namespace {
 template <typename T>
 void OpenCLRandomTest(const std::vector<std::vector<index_t>> &shapes,
                       const int axis,
-                      DataFormat data_format) {
+                      bool has_data_format) {
   srand(time(nullptr));
   int num_inputs = shapes.size();
   int concat_axis_size = 0;
@@ -275,7 +287,7 @@ void OpenCLRandomTest(const std::vector<std::vector<index_t>> &shapes,
   builder.AddIntArg("axis", axis)
       .Output("Output")
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-      .AddIntArg("data_format", data_format)
+      .AddIntArg("has_data_format", has_data_format)
       .OutputShape(expected_shape)
       .Finalize(net.NewOperatorDef());
@@ -309,38 +321,37 @@ void OpenCLRandomTest(const std::vector<std::vector<index_t>> &shapes,
 }  // namespace

 TEST_F(ConcatOpTest, OPENCLAligned) {
-  OpenCLRandomTest<float>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3,
-                          DataFormat::NHWC);
+  OpenCLRandomTest<float>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3, 1);
 }

 TEST_F(ConcatOpTest, OPENCLHalfAligned) {
-  OpenCLRandomTest<half>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3,
-                         DataFormat::NHWC);
+  OpenCLRandomTest<half>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3, 1);
 }

 TEST_F(ConcatOpTest, OPENCLUnAligned) {
-  OpenCLRandomTest<float>({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3,
-                          DataFormat::NHWC);
+  OpenCLRandomTest<float>({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3, 1);
 }

 TEST_F(ConcatOpTest, OPENCLAlignedMultiInput) {
   OpenCLRandomTest<float>(
       {{3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}},
-      3, DataFormat::NHWC);
+      3, 1);
 }

 TEST_F(ConcatOpTest, GPUFallbackToCPU2DInput) {
-  OpenCLRandomTest<float>({{3, 4}, {3, 4}}, 1, DataFormat::DF_NONE);
+  OpenCLRandomTest<float>({{3, 4}, {3, 4}}, 1, 0);
 }

 TEST_F(ConcatOpTest, GPUFallbackToCPUChanNotDivisibleBy4) {
-  OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 3,
-                          DataFormat::DF_NONE);
+  OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 3, 0);
+}
+
+TEST_F(ConcatOpTest, GPUFallbackToCPUNoDataFormat) {
+  OpenCLRandomTest<float>({{1, 1, 4, 4}, {1, 1, 4, 4}}, 3, 0);
 }

 TEST_F(ConcatOpTest, GPUFallbackToCPUAxis2) {
-  OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 2,
-                          DataFormat::DF_NONE);
+  OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 2, 0);
 }

 }  // namespace test
......
@@ -897,8 +897,8 @@ class EltwiseOp : public Operation {
         scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)),
         scalar_input_index_(Operation::GetOptionalArg<int32_t>(
             "scalar_input_index", 1)),
-        data_format_(static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-            "data_format", 0))) {}
+        has_data_format_(Operation::GetOptionalArg<int>(
+            "has_data_format", 0)) {}

   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
@@ -940,7 +940,7 @@ class EltwiseOp : public Operation {
       // check if we can broadcast tensor
       uint32_t rank_diff =
           static_cast<uint32_t>(input0->dim_size() - input1->dim_size());
-      if (data_format_ == NCHW) {
+      if (has_data_format_) {
         MACE_CHECK(
             (input0->dim_size() == 4) &&
             ((input1->dim_size() == 0) ||
@@ -965,7 +965,7 @@ class EltwiseOp : public Operation {
     const T *input0_ptr = input0->data<T>();
     const T *input1_ptr = input1->data<T>();

-    if (data_format_ == NCHW && input1->dim_size() > 0) {
+    if (has_data_format_ && input1->dim_size() > 0) {
       MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
       Tensor::MappingGuard output_guard(output);
       DstType *output_ptr = output->mutable_data<DstType>();
@@ -1027,7 +1027,7 @@ class EltwiseOp : public Operation {
   std::vector<float> coeff_;
   float scalar_input_;
   int32_t scalar_input_index_;
-  DataFormat data_format_;
+  int has_data_format_;
   Tensor scalar_tensor_;
 };
@@ -1042,9 +1042,7 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
         coeff_(Operation::GetRepeatedArgs<float>("coeff")),
         scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)),
         scalar_input_index_(Operation::GetOptionalArg<int32_t>(
-            "scalar_input_index", 1)),
-        data_format_(static_cast<DataFormat>(Operation::GetOptionalArg<int>(
-            "data_format", 0)))
+            "scalar_input_index", 1))
 #ifdef MACE_ENABLE_NEON
       , eltwise_(static_cast<ops::EltwiseType>(Operation::GetOptionalArg<int>(
             "type", static_cast<int>(ops::EltwiseType::NONE))))
@@ -1139,7 +1137,6 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
   std::vector<float> coeff_;
   float scalar_input_;
   int32_t scalar_input_index_;
-  DataFormat data_format_;
   Tensor scalar_tensor_;
 #ifdef MACE_ENABLE_NEON
   arm::q8::Eltwise eltwise_;
......
@@ -44,6 +44,7 @@ void EltwiseBenchmark(
       .AddIntArg("type", static_cast<int>(type))
       .AddFloatsArg("coeff", {1.2, 2.1})
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .AddIntArg("has_data_format", 1)
       .Output("Output")
       .Finalize(net.NewOperatorDef());
......
@@ -75,7 +75,7 @@ void SimpleTensorScalar(const ops::EltwiseType type,
       .AddIntArg("T", DataTypeToEnum<T>::v())
       .AddIntArg("type", static_cast<int>(type))
       .AddFloatArg("scalar_input", x)
-      .AddIntArg("data_format", DataFormat::NCHW)
+      .AddIntArg("has_data_format", 1)
       .OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
       .Output("TOutput")
       .Finalize(net.NewOperatorDef());
@@ -120,7 +120,7 @@ void SimpleTensorEltwise(const ops::EltwiseType type,
       .AddIntArg("T", DataTypeToEnum<T>::v())
       .AddIntArg("type", static_cast<int>(type))
       .AddFloatsArg("coeff", coeff)
-      .AddIntArg("data_format", DataFormat::NCHW)
+      .AddIntArg("has_data_format", 1)
       .OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
       .Output("TOutput");
   if (shape0.size() > 1) {
@@ -642,7 +642,7 @@ void RandomTensorScalar(const ops::EltwiseType type,
       .Input("TInput")
       .AddIntArg("type", static_cast<int>(type))
       .AddFloatArg("scalar_input", 0.1)
-      .AddIntArg("data_format", DataFormat::NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("TOutput")
       .Finalize(net.NewOperatorDef());
   // Run
@@ -699,7 +699,7 @@ void RandomTensorEltwise(const ops::EltwiseType type,
       .Input("TInput1")
       .AddIntArg("type", static_cast<int>(type))
       .AddFloatsArg("coeff", coeff)
-      .AddIntArg("data_format", DataFormat::NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("TOutput")
       .Finalize(net.NewOperatorDef());
@@ -755,7 +755,7 @@ void Quantized(const std::vector<index_t> &shape,
       .Input("TInput0")
       .Input("TInput1")
       .AddIntArg("type", static_cast<int>(type))
-      .AddIntArg("data_format", DataFormat::NCHW)
+      .AddIntArg("has_data_format", 1)
       .Output("TOutput")
       .Finalize(net.NewOperatorDef());
......
@@ -34,9 +34,9 @@ class InferConv2dShapeOp : public Operation {
     Tensor::MappingGuard output_guard(output);
     int32_t *output_data = output->mutable_data<int32_t>();

-    const int32_t data_format =
-        Operation::GetOptionalArg<int>("data_format", 0);
-    const bool isNCHW = data_format == 1;
+    auto has_data_format =
+        Operation::GetOptionalArg<int>("has_data_format", 0);
+    const bool isNCHW = (has_data_format && D == DeviceType::CPU);

     Padding padding_type =
         static_cast<Padding>(Operation::GetOptionalArg<int>(
......
@@ -57,8 +57,8 @@ void TestInferConv2dShapeOp(const std::vector<index_t> &input_shape,
 }  // namespace

 TEST_F(InferConv2dShapeOpTest, TestInferConv2dShape) {
   TestInferConv2dShapeOp({3, 640, 480, 16}, 1, {3, 640, 480, 3});
   TestInferConv2dShapeOp({3, 640, 480, 16}, 2, {3, 320, 240, 3});
 }

 }  // namespace test
......
...@@ -48,7 +48,7 @@ class OpenCLBufferTransformer { ...@@ -48,7 +48,7 @@ class OpenCLBufferTransformer {
const OpenCLBufferType type, const OpenCLBufferType type,
const MemoryType out_mem_type, const MemoryType out_mem_type,
const int wino_blk_size, const int wino_blk_size,
const DataFormat data_format, bool has_data_format,
Tensor *output) { Tensor *output) {
Workspace *ws = context->workspace(); Workspace *ws = context->workspace();
DataType dt = DataTypeToEnum<T>::value; DataType dt = DataTypeToEnum<T>::value;
...@@ -67,13 +67,14 @@ class OpenCLBufferTransformer { ...@@ -67,13 +67,14 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform CPU Buffer " << input->name() VLOG(2) << "Transform CPU Buffer " << input->name()
<< " to GPU Buffer " << internal_tensor->name() << " to GPU Buffer " << internal_tensor->name()
<< " with data type " << dt; << " with data type " << dt;
if (data_format == DataFormat::NHWC && input->shape().size() == 4) { if (has_data_format && input->shape().size() == 4) {
// 1. (NCHW -> NHWC) // 1. (NCHW -> NHWC)
std::vector<int> dst_dims = {0, 2, 3, 1}; std::vector<int> dst_dims = {0, 2, 3, 1};
std::vector<index_t> output_shape = std::vector<index_t> output_shape =
TransposeShape<index_t, index_t>(input->shape(), TransposeShape<index_t, index_t>(input->shape(),
dst_dims); dst_dims);
internal_tensor->Resize(output_shape); internal_tensor->Resize(output_shape);
internal_tensor->set_data_format(DataFormat::NHWC);
// TODO(liuqi): Only support float now // TODO(liuqi): Only support float now
const float *input_ptr = input->data<float>(); const float *input_ptr = input->data<float>();
Tensor::MappingGuard guard(internal_tensor); Tensor::MappingGuard guard(internal_tensor);
...@@ -105,13 +106,13 @@ class OpenCLBufferTransformer { ...@@ -105,13 +106,13 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform GPU Buffer " << internal_tensor.name() VLOG(2) << "Transform GPU Buffer " << internal_tensor.name()
<< " to CPU Buffer " << output->name() << " to CPU Buffer " << output->name()
<< " with data type " << dt; << " with data type " << dt;
if (data_format == DataFormat::NHWC && if (has_data_format && internal_tensor.shape().size() == 4) {
internal_tensor.shape().size() == 4) {
// NHWC -> NCHW // NHWC -> NCHW
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> output_shape = std::vector<index_t> output_shape =
TransposeShape<index_t, index_t>(internal_tensor.shape(), TransposeShape<index_t, index_t>(internal_tensor.shape(),
dst_dims); dst_dims);
output->set_data_format(DataFormat::NCHW);
Tensor::MappingGuard guard(&internal_tensor); Tensor::MappingGuard guard(&internal_tensor);
const float *internal_ptr = internal_tensor.data<float>(); const float *internal_ptr = internal_tensor.data<float>();
output->Resize(output_shape); output->Resize(output_shape);
......
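The two permutations in this transformer are mutual inverses: {0, 2, 3, 1} carries CPU NCHW buffers into the GPU's NHWC layout, {0, 3, 1, 2} carries results back, and the patch now records the layout on each tensor via set_data_format. A numpy sketch of the round trip (shapes illustrative only):

import numpy as np

x_nchw = np.arange(24).reshape(1, 2, 3, 4)   # N, C, H, W
x_nhwc = x_nchw.transpose(0, 2, 3, 1)        # dst_dims {0, 2, 3, 1}: NCHW -> NHWC
back = x_nhwc.transpose(0, 3, 1, 2)          # dst_dims {0, 3, 1, 2}: NHWC -> NCHW
assert (back == x_nchw).all()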
...@@ -166,9 +166,20 @@ bool OpsTestNet::Setup(mace::DeviceType device) { ...@@ -166,9 +166,20 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
!ws_.GetTensor(input)->is_weight()) { !ws_.GetTensor(input)->is_weight()) {
auto input_info = net_def.add_input_info(); auto input_info = net_def.add_input_info();
input_info->set_name(input); input_info->set_name(input);
auto data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>( auto has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "data_format", DataFormat::DF_NONE); op_def, "has_data_format", 1);
input_info->set_data_format(data_format); auto is_quantized_op = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DT_FLOAT))
== static_cast<int>(DT_UINT8);
if (has_data_format) {
if (is_quantized_op || device == DeviceType::GPU) {
input_info->set_data_format(NHWC);
} else {
input_info->set_data_format(NCHW);
}
} else {
input_info->set_data_format(DataFormat::DF_NONE);
}
auto &shape = ws_.GetTensor(input)->shape(); auto &shape = ws_.GetTensor(input)->shape();
for (auto d : shape) { for (auto d : shape) {
input_info->add_dims(static_cast<int>(d)); input_info->add_dims(static_cast<int>(d));
......
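Setup now derives each input's declared format instead of copying a per-op data_format through: quantized ops and the GPU runtime get NHWC, the float CPU path gets NCHW, and a cleared flag means DF_NONE. A hedged Python restatement of that branch (names illustrative):

def input_format(has_data_format, is_quantized_op, device):
    # Same decision tree as the OpsTestNet::Setup change above.
    if not has_data_format:
        return "NONE"
    return "NHWC" if (is_quantized_op or device == "GPU") else "NCHW"

assert input_format(1, False, "CPU") == "NCHW"
assert input_format(1, True, "CPU") == "NHWC"
assert input_format(0, False, "GPU") == "NONE"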
...@@ -40,9 +40,9 @@ class PadOp<DeviceType::CPU, T> : public Operation { ...@@ -40,9 +40,9 @@ class PadOp<DeviceType::CPU, T> : public Operation {
constant_value_(Operation::GetOptionalArg<float>( constant_value_(Operation::GetOptionalArg<float>(
"constant_value", 0.0)) { "constant_value", 0.0)) {
MACE_CHECK(paddings_.size() == 8); MACE_CHECK(paddings_.size() == 8);
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>( auto has_df = Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE)); "has_data_format", 0);
if (df == DataFormat::NHWC) { if (has_df) {
paddings_ = TransposeShape<int, int>(paddings_, {0, 1, 6, 7, 2, 3, 4, 5}); paddings_ = TransposeShape<int, int>(paddings_, {0, 1, 6, 7, 2, 3, 4, 5});
} }
} }
...@@ -55,11 +55,9 @@ class PadOp<DeviceType::CPU, T> : public Operation { ...@@ -55,11 +55,9 @@ class PadOp<DeviceType::CPU, T> : public Operation {
this->paddings_.size() == static_cast<size_t>(input->dim_size()) * 2); this->paddings_.size() == static_cast<size_t>(input->dim_size()) * 2);
auto input_shape = input->shape(); auto input_shape = input->shape();
for (size_t i = 0; i < paddings_.size(); ++i) { for (size_t i = 0; i < paddings_.size(); ++i) {
if (type_ == PadType::REFLECT) { if (type_ == PadType::REFLECT || type_ == PadType::SYMMETRIC) {
MACE_CHECK(paddings_[i] < input_shape[i / 2]); MACE_CHECK(paddings_[i] < input_shape[i / 2], paddings_[i],
" vs ", input_shape[i / 2]);
} else if (type_ == PadType::SYMMETRIC) {
MACE_CHECK(paddings_[i] <= input_shape[i / 2]);
} }
MACE_CHECK(paddings_[i] >= 0); MACE_CHECK(paddings_[i] >= 0);
} }
......
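The permutation {0, 1, 6, 7, 2, 3, 4, 5} applied to the eight paddings moves the (begin, end) pad pairs from NHWC order to NCHW order: the N pair stays put and the C pair jumps ahead of H and W. A sketch using the same out[i] = in[order[i]] convention as TransposeShape (the labels are illustrative):

def transpose_pads(pads, order=(0, 1, 6, 7, 2, 3, 4, 5)):
    return [pads[j] for j in order]

nhwc_pads = ["n0", "n1", "h0", "h1", "w0", "w1", "c0", "c1"]
assert transpose_pads(nhwc_pads) == ["n0", "n1", "c0", "c1",
                                     "h0", "h1", "w0", "w1"]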
...@@ -29,7 +29,11 @@ void Pad(int iters, int batch, int height, ...@@ -29,7 +29,11 @@ void Pad(int iters, int batch, int height,
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
}
const std::vector<int> paddings = {0, 0, pad, pad, pad, pad, 0, 0}; const std::vector<int> paddings = {0, 0, pad, pad, pad, pad, 0, 0};
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
...@@ -37,6 +41,7 @@ void Pad(int iters, int batch, int height, ...@@ -37,6 +41,7 @@ void Pad(int iters, int batch, int height,
.Output("Output") .Output("Output")
.AddIntsArg("paddings", paddings) .AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", pad_type) .AddIntArg("pad_type", pad_type)
.AddIntArg("has_data_format", 1)
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
...@@ -39,7 +39,7 @@ void SimpleConstant() { ...@@ -39,7 +39,7 @@ void SimpleConstant() {
.Output("Output") .Output("Output")
.AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -52,7 +52,7 @@ void SimpleConstant() { ...@@ -52,7 +52,7 @@ void SimpleConstant() {
.Output("TOutput") .Output("TOutput")
.AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -101,7 +101,7 @@ void Result(const std::vector<index_t> &input_shape, ...@@ -101,7 +101,7 @@ void Result(const std::vector<index_t> &input_shape,
.Output(t_output) .Output(t_output)
.AddIntsArg("paddings", paddings) .AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", static_cast<int>(pad_type)) .AddIntArg("pad_type", static_cast<int>(pad_type))
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -179,7 +179,7 @@ TEST_F(PadTest, ComplexCPU) { ...@@ -179,7 +179,7 @@ TEST_F(PadTest, ComplexCPU) {
.Output("TOutput") .Output("TOutput")
.AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1}) .AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1})
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -217,7 +217,7 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -217,7 +217,7 @@ void Complex(const std::vector<index_t> &input_shape,
.AddIntsArg("paddings", paddings) .AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", pad_type) .AddIntArg("pad_type", pad_type)
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -234,7 +234,7 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -234,7 +234,7 @@ void Complex(const std::vector<index_t> &input_shape,
.AddIntsArg("paddings", paddings) .AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", pad_type) .AddIntArg("pad_type", pad_type)
.AddFloatArg("constant_value", 1.0) .AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
......
...@@ -94,9 +94,9 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase { ...@@ -94,9 +94,9 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
int index = axis_[i] >= 0 ? int index = axis_[i] >= 0 ?
axis_[i] : axis_[i] :
axis_[i] + input->dim_size(); axis_[i] + input->dim_size();
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>( auto has_df = Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE)); "has_data_format", 0);
if (df == DataFormat::NHWC && DataTypeToEnum<T>::value != DT_UINT8 if (has_df && DataTypeToEnum<T>::value != DT_UINT8
&& input->dim_size() == 4) { && input->dim_size() == 4) {
if (index == 1 || index == 2) index = index + 1; if (index == 1 || index == 2) index = index + 1;
else if (index == 3) index = 1; else if (index == 3) index = 1;
......
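With the flag set on a non-quantized 4-D CPU tensor, an axis given in NHWC terms is remapped to its NCHW position: H(1) -> 2, W(2) -> 3, C(3) -> 1, batch unchanged. The Split and Squeeze changes further down apply the same table. A one-table sketch (illustrative):

NHWC_TO_NCHW_AXIS = {0: 0, 1: 2, 2: 3, 3: 1}

# Spatial reduce over H, W and channel reduce over C:
assert [NHWC_TO_NCHW_AXIS[a] for a in (1, 2)] == [2, 3]
assert NHWC_TO_NCHW_AXIS[3] == 1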
...@@ -38,6 +38,7 @@ void Reduce(int iters, int batch, int channels, ...@@ -38,6 +38,7 @@ void Reduce(int iters, int batch, int channels,
.Input("Input") .Input("Input")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.Output("OutputImage") .Output("OutputImage")
.AddIntArg("has_data_format", 1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
...@@ -44,7 +44,7 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -44,7 +44,7 @@ void Simple(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("OutputNCHW") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -56,7 +56,7 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -56,7 +56,7 @@ void Simple(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -84,7 +84,7 @@ void Simple3D(const std::vector<index_t> &input_shape, ...@@ -84,7 +84,7 @@ void Simple3D(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -588,7 +588,7 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -588,7 +588,7 @@ void RandomTest(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1) .AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("OutputNCHW") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -600,7 +600,7 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -600,7 +600,7 @@ void RandomTest(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1) .AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("OPENCLOutput") .Output("OPENCLOutput")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -662,7 +662,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -662,7 +662,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1) .AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("OutputNCHW") .Output("OutputNCHW")
.AddIntArg("T", DT_FLOAT) .AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
...@@ -687,7 +687,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -687,7 +687,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1) .AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type) .AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.AddIntArg("T", DT_UINT8) .AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(); net.RunOp();
......
...@@ -77,9 +77,9 @@ class ReshapeOp : public Operation { ...@@ -77,9 +77,9 @@ class ReshapeOp : public Operation {
} }
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
// NHWC -> NCHW // NHWC -> NCHW
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>( auto has_df = Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE)); "has_data_format", 0);
if (df == DataFormat::NHWC && D == DeviceType::CPU if (has_df && D == DeviceType::CPU
&& out_shape.size() == 4 && shape->is_weight()) { && out_shape.size() == 4 && shape->is_weight()) {
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>( std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>(
......
...@@ -35,11 +35,10 @@ class ShapeOp : public Operation { ...@@ -35,11 +35,10 @@ class ShapeOp : public Operation {
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
int32_t *output_data = output->mutable_data<int32_t>(); int32_t *output_data = output->mutable_data<int32_t>();
const int data_format = auto has_df = Operation::GetOptionalArg<int>(
Operation::GetOptionalArg<int>("data_format", 0); "has_data_format", 0);
if (input->dim_size() == 4 && if (D == DeviceType::CPU &&
D == DeviceType::CPU && has_df && input->dim_size() == 4) {
data_format == DataFormat::NCHW) {
// transpose NCHW to NHWC for cpu runtime // transpose NCHW to NHWC for cpu runtime
output_data[0] = static_cast<int32_t>(input->dim(0)); output_data[0] = static_cast<int32_t>(input->dim(0));
output_data[1] = static_cast<int32_t>(input->dim(2)); output_data[1] = static_cast<int32_t>(input->dim(2));
......
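On CPU with the flag set, the stored shape is NCHW but callers expect NHWC, so the op emits dims in the order 0, 2, 3, 1 (the visible lines show the first two; the rest is truncated in this view). A sketch of the presumed full mapping (an assumption based on the pattern above):

def reported_shape(stored_nchw):
    n, c, h, w = stored_nchw
    return [n, h, w, c]  # present NCHW storage as NHWC to the caller

assert reported_shape([1, 3, 224, 224]) == [1, 224, 224, 3]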
...@@ -36,9 +36,9 @@ class SplitOp<DeviceType::CPU, T> : public Operation { ...@@ -36,9 +36,9 @@ class SplitOp<DeviceType::CPU, T> : public Operation {
checked_(false) {} checked_(false) {}
void Validate() { void Validate() {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>( auto has_df = Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE)); "has_data_format", 0);
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { if (has_df && this->Input(0)->dim_size() == 4) {
if (axis_ == 3) axis_ = 1; if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3; else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2; else if (axis_ == 1) axis_ = 2;
......
...@@ -44,6 +44,7 @@ void BMSplitHelper(int iters, ...@@ -44,6 +44,7 @@ void BMSplitHelper(int iters,
} }
builder builder
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Warm-up // Warm-up
......
...@@ -54,7 +54,7 @@ void RandomTest(const int num_outputs, int axis) { ...@@ -54,7 +54,7 @@ void RandomTest(const int num_outputs, int axis) {
builder = builder.Output(MakeString("Output", i)); builder = builder.Output(MakeString("Output", i));
} }
builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
......
...@@ -32,9 +32,9 @@ class SqueezeOp : public Operation { ...@@ -32,9 +32,9 @@ class SqueezeOp : public Operation {
MACE_UNUSED(context); MACE_UNUSED(context);
if (!checked_ && D == DeviceType::CPU if (!checked_ && D == DeviceType::CPU
&& DataTypeToEnum<T>::value != DT_UINT8) { && DataTypeToEnum<T>::value != DT_UINT8) {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>( auto has_df = Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE)); "has_data_format", 0);
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { if (has_df && this->Input(0)->dim_size() == 4) {
if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) { if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) {
axis_[0] = 2; axis_[0] = 2;
axis_[1] = 3; axis_[1] = 3;
......
...@@ -30,7 +30,7 @@ void TestSqueeze(const std::vector<index_t> &org_shape, ...@@ -30,7 +30,7 @@ void TestSqueeze(const std::vector<index_t> &org_shape,
OpDefBuilder("Squeeze", "SqueezeTest") OpDefBuilder("Squeeze", "SqueezeTest")
.Input("Input") .Input("Input")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
.AddIntArg("data_format", DataFormat::NHWC) .AddIntArg("has_data_format", 1)
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
...@@ -42,6 +42,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value, ...@@ -42,6 +42,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value,
data_format_map = { data_format_map = {
'NONE': cvt.DataFormat.DF_NONE, 'NONE': cvt.DataFormat.DF_NONE,
'NHWC': cvt.DataFormat.NHWC, 'NHWC': cvt.DataFormat.NHWC,
'NCHW': cvt.DataFormat.NCHW,
} }
...@@ -74,6 +75,13 @@ def parse_float_array_from_str(ints_str): ...@@ -74,6 +75,13 @@ def parse_float_array_from_str(ints_str):
return [float(int_str) for int_str in ints_str.split(',')] return [float(int_str) for int_str in ints_str.split(',')]
def transpose_shape(shape, dst_order):
t_shape = [0] * len(shape)
for i in range(len(shape)):
t_shape[i] = shape[dst_order[i]]
return t_shape
def main(unused_args): def main(unused_args):
if not os.path.isfile(FLAGS.model_file): if not os.path.isfile(FLAGS.model_file):
six.print_("Input graph file '" + six.print_("Input graph file '" +
...@@ -139,6 +147,10 @@ def main(unused_args): ...@@ -139,6 +147,10 @@ def main(unused_args):
else: else:
input_node.data_format = data_format_map[input_node_formats[i]] input_node.data_format = data_format_map[input_node_formats[i]]
input_node.shape = parse_int_array_from_str(input_node_shapes[i]) input_node.shape = parse_int_array_from_str(input_node_shapes[i])
if input_node.data_format == cvt.DataFormat.NCHW and\
len(input_node.shape) == 4:
input_node.shape = transpose_shape(input_node.shape, [0, 2, 3, 1])
input_node.data_format = cvt.DataFormat.NHWC
if len(input_node_ranges) > i: if len(input_node_ranges) > i:
input_node.range = parse_float_array_from_str(input_node_ranges[i]) input_node.range = parse_float_array_from_str(input_node_ranges[i])
option.add_input_node(input_node) option.add_input_node(input_node)
...@@ -156,6 +168,11 @@ def main(unused_args): ...@@ -156,6 +168,11 @@ def main(unused_args):
else: else:
output_node.data_format = data_format_map[output_node_formats[i]] output_node.data_format = data_format_map[output_node_formats[i]]
output_node.shape = parse_int_array_from_str(output_node_shapes[i]) output_node.shape = parse_int_array_from_str(output_node_shapes[i])
if output_node.data_format == cvt.DataFormat.NCHW and\
len(output_node.shape) == 4:
output_node.shape = transpose_shape(output_node.shape,
[0, 2, 3, 1])
output_node.data_format = cvt.DataFormat.NHWC
option.add_output_node(output_node) option.add_output_node(output_node)
if FLAGS.check_node != '': if FLAGS.check_node != '':
......
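The new transpose_shape helper follows the out[i] = shape[dst_order[i]] convention, so the two call sites above turn a user-declared NCHW shape into the NHWC shape the converter works with, then relabel the node NHWC. A self-contained usage check (the one-line body is equivalent to the loop added above):

def transpose_shape(shape, dst_order):
    return [shape[i] for i in dst_order]

assert transpose_shape([1, 3, 224, 224], [0, 2, 3, 1]) == [1, 224, 224, 3]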
...@@ -181,6 +181,7 @@ class MaceKeyword(object): ...@@ -181,6 +181,7 @@ class MaceKeyword(object):
mace_global_pooling_str = 'global_pooling' mace_global_pooling_str = 'global_pooling'
mace_kernel_str = 'kernels' mace_kernel_str = 'kernels'
mace_data_format_str = 'data_format' mace_data_format_str = 'data_format'
mace_has_data_format_str = 'has_data_format'
mace_filter_format_str = 'filter_format' mace_filter_format_str = 'filter_format'
mace_element_type_str = 'type' mace_element_type_str = 'type'
mace_activation_type_str = 'activation' mace_activation_type_str = 'activation'
...@@ -525,6 +526,16 @@ class ConverterUtil(object): ...@@ -525,6 +526,16 @@ class ConverterUtil(object):
return arg return arg
return None return None
@staticmethod
def del_arg(op, arg_name):
found_idx = -1
for idx in range(len(op.arg)):
if op.arg[idx].name == arg_name:
found_idx = idx
break
if found_idx != -1:
del op.arg[found_idx]
@staticmethod @staticmethod
def add_data_format_arg(op, data_format): def add_data_format_arg(op, data_format):
data_format_arg = op.arg.add() data_format_arg = op.arg.add()
......
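del_arg removes the first arg with a matching name, which is what lets the transformer below strip the now-ambiguous data_format before appending has_data_format. A hedged usage sketch against stand-in objects (FakeArg and FakeOp are illustrative, not the converter's real protobuf types):

class FakeArg:
    def __init__(self, name):
        self.name = name

class FakeOp:
    def __init__(self):
        self.arg = [FakeArg("data_format"), FakeArg("T")]

op = FakeOp()
for idx in range(len(op.arg)):          # first-match delete, as in del_arg
    if op.arg[idx].name == "data_format":
        del op.arg[idx]
        break
assert [a.name for a in op.arg] == ["T"]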
...@@ -1406,21 +1406,17 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1406,21 +1406,17 @@ class Transformer(base_converter.ConverterInterface):
def update_data_format(self): def update_data_format(self):
print("update data format") print("update data format")
data_format_flag = DataFormat.NHWC.value data_format_flag = 1
for input_node in self._option.input_nodes.values(): for input_node in self._option.input_nodes.values():
if input_node.data_format.value == DataFormat.DF_NONE.value: if input_node.data_format.value == DataFormat.DF_NONE.value:
data_format_flag = DataFormat.DF_NONE.value data_format_flag = 0
net = self._model net = self._model
for op in net.op: for op in net.op:
data_format_arg = ConverterUtil.get_arg( ConverterUtil.del_arg(
op, MaceKeyword.mace_data_format_str) op, MaceKeyword.mace_data_format_str)
if not data_format_arg: has_data_format_arg = op.arg.add()
data_format_arg = op.arg.add() has_data_format_arg.name = MaceKeyword.mace_has_data_format_str
data_format_arg.name = MaceKeyword.mace_data_format_str has_data_format_arg.i = data_format_flag
data_format_arg.i = data_format_flag
elif data_format_arg.i != data_format_flag:
data_format_arg.i = data_format_flag
return False return False
def quantize_nodes(self): def quantize_nodes(self):
......
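Net effect of the rewritten pass: the per-op layout enum is gone, and every op carries one has_data_format bit, cleared only when some input node declared no format at all. A sketch of that aggregation (string values are stand-ins for the DataFormat enum):

def model_has_data_format(input_node_formats):
    # 0 as soon as any input node is format-less, else 1.
    return 0 if "NONE" in input_node_formats else 1

assert model_has_data_format(["NHWC", "NHWC"]) == 1
assert model_has_data_format(["NHWC", "NONE"]) == 0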
...@@ -46,6 +46,7 @@ void MaceRunFunc(const int in_out_size) { ...@@ -46,6 +46,7 @@ void MaceRunFunc(const int in_out_size) {
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputInfo *info = net_def->add_input_info(); InputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC);
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : input_shapes[0]) { for (auto d : input_shapes[0]) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -45,6 +45,7 @@ void MaceRun(const int in_out_size, ...@@ -45,6 +45,7 @@ void MaceRun(const int in_out_size,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputInfo *info = net_def->add_input_info(); InputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC);
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : max_shape) { for (auto d : max_shape) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -76,6 +76,7 @@ void Conv3x3(const std::string &input_name, ...@@ -76,6 +76,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1)
.Finalize(&operator_def); .Finalize(&operator_def);
OutputShape *shape = operator_def.add_output_shape(); OutputShape *shape = operator_def.add_output_shape();
...@@ -98,6 +99,7 @@ void Relu(const std::string &input_name, ...@@ -98,6 +99,7 @@ void Relu(const std::string &input_name,
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type)) .AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("has_data_format", 1)
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
......
...@@ -103,6 +103,16 @@ DeviceType ParseDeviceType(const std::string &device_str) { ...@@ -103,6 +103,16 @@ DeviceType ParseDeviceType(const std::string &device_str) {
} }
} }
DataFormat ParseDataFormat(const std::string &data_format_str) {
if (data_format_str == "NHWC") {
return DataFormat::NHWC;
} else if (data_format_str == "NCHW") {
return DataFormat::NCHW;
} else {
return DataFormat::DF_NONE;
}
}
struct mallinfo LogMallinfoChange(struct mallinfo prev) { struct mallinfo LogMallinfoChange(struct mallinfo prev) {
struct mallinfo curr = mallinfo(); struct mallinfo curr = mallinfo();
if (prev.arena != curr.arena) { if (prev.arena != curr.arena) {
...@@ -168,6 +178,12 @@ DEFINE_string(output_node, ...@@ -168,6 +178,12 @@ DEFINE_string(output_node,
DEFINE_string(output_shape, DEFINE_string(output_shape,
"1,224,224,2:1,1,1,10", "1,224,224,2:1,1,1,10",
"output shapes, separated by colon and comma"); "output shapes, separated by colon and comma");
DEFINE_string(input_data_format,
"NHWC",
"input data formats, NONE|NHWC|NCHW");
DEFINE_string(output_data_format,
"NHWC",
"output data formats, NONE|NHWC|NCHW");
DEFINE_string(input_file, DEFINE_string(input_file,
"", "",
"input file name | input file prefix for multiple inputs."); "input file name | input file prefix for multiple inputs.");
...@@ -206,8 +222,10 @@ DEFINE_int32(cpu_affinity_policy, 1, ...@@ -206,8 +222,10 @@ DEFINE_int32(cpu_affinity_policy, 1,
bool RunModel(const std::string &model_name, bool RunModel(const std::string &model_name,
const std::vector<std::string> &input_names, const std::vector<std::string> &input_names,
const std::vector<std::vector<int64_t>> &input_shapes, const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<DataFormat> &input_data_formats,
const std::vector<std::string> &output_names, const std::vector<std::string> &output_names,
const std::vector<std::vector<int64_t>> &output_shapes) { const std::vector<std::vector<int64_t>> &output_shapes,
const std::vector<DataFormat> &output_data_formats) {
DeviceType device_type = ParseDeviceType(FLAGS_device); DeviceType device_type = ParseDeviceType(FLAGS_device);
int64_t t0 = NowMicros(); int64_t t0 = NowMicros();
...@@ -325,7 +343,8 @@ bool RunModel(const std::string &model_name, ...@@ -325,7 +343,8 @@ bool RunModel(const std::string &model_name,
LOG(INFO) << "Open input file failed"; LOG(INFO) << "Open input file failed";
return -1; return -1;
} }
inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in); inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in,
input_data_formats[i]);
} }
for (size_t i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
...@@ -334,7 +353,8 @@ bool RunModel(const std::string &model_name, ...@@ -334,7 +353,8 @@ bool RunModel(const std::string &model_name,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
auto buffer_out = std::shared_ptr<float>(new float[output_size], auto buffer_out = std::shared_ptr<float>(new float[output_size],
std::default_delete<float[]>()); std::default_delete<float[]>());
outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out); outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out,
output_data_formats[i]);
} }
LOG(INFO) << "Warm up run"; LOG(INFO) << "Warm up run";
...@@ -498,13 +518,27 @@ int Main(int argc, char **argv) { ...@@ -498,13 +518,27 @@ int Main(int argc, char **argv) {
for (size_t i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
ParseShape(output_shapes[i], &output_shape_vec[i]); ParseShape(output_shapes[i], &output_shape_vec[i]);
} }
std::vector<std::string> raw_input_data_formats =
str_util::Split(FLAGS_input_data_format, ',');
std::vector<std::string> raw_output_data_formats =
str_util::Split(FLAGS_output_data_format, ',');
std::vector<DataFormat> input_data_formats(input_count);
std::vector<DataFormat> output_data_formats(output_count);
for (size_t i = 0; i < input_count; ++i) {
input_data_formats[i] = ParseDataFormat(raw_input_data_formats[i]);
}
for (size_t i = 0; i < output_count; ++i) {
output_data_formats[i] = ParseDataFormat(raw_output_data_formats[i]);
}
bool ret = false; bool ret = false;
for (int i = 0; i < FLAGS_restart_round; ++i) { for (int i = 0; i < FLAGS_restart_round; ++i) {
VLOG(0) << "restart round " << i; VLOG(0) << "restart round " << i;
ret = ret =
RunModel(FLAGS_model_name, input_names, input_shape_vec, RunModel(FLAGS_model_name,
output_names, output_shape_vec); input_names, input_shape_vec, input_data_formats,
output_names, output_shape_vec, output_data_formats);
} }
if (ret) { if (ret) {
return 0; return 0;
......
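One caveat about the parsing loops above: raw_input_data_formats[i] is indexed once per input, so the comma-separated flag must supply one entry per tensor or the vector access runs past the split result. A defensive variant, offered as an assumption-laden sketch rather than what the patch does:

def parse_formats(flag, count):
    parts = flag.split(",")
    parts += [parts[-1]] * (count - len(parts))  # repeat the last entry
    return parts[:count]

assert parse_formats("NHWC", 3) == ["NHWC", "NHWC", "NHWC"]
assert parse_formats("NHWC,NCHW", 2) == ["NHWC", "NCHW"]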
...@@ -131,6 +131,12 @@ class DeviceType(object): ...@@ -131,6 +131,12 @@ class DeviceType(object):
HEXAGON = 'HEXAGON' HEXAGON = 'HEXAGON'
class DataFormat(object):
NONE = "NONE"
NHWC = "NHWC"
NCHW = "NCHW"
################################ ################################
# Argument types # Argument types
################################ ################################
......
...@@ -96,14 +96,10 @@ WinogradParameters = [0, 2, 4] ...@@ -96,14 +96,10 @@ WinogradParameters = [0, 2, 4]
DataFormatStrs = [ DataFormatStrs = [
"NONE", "NONE",
"NHWC", "NHWC",
"NCHW",
] ]
class DataFormat(object):
NONE = "NONE"
NHWC = "NHWC"
class DefaultValues(object): class DefaultValues(object):
mace_lib_type = MACELibType.static mace_lib_type = MACELibType.static
omp_num_threads = -1, omp_num_threads = -1,
...@@ -371,6 +367,15 @@ def format_model_config(flags): ...@@ -371,6 +367,15 @@ def format_model_config(flags):
if not isinstance(value, list): if not isinstance(value, list):
subgraph[key] = [value] subgraph[key] = [value]
subgraph[key] = [str(v) for v in subgraph[key]] subgraph[key] = [str(v) for v in subgraph[key]]
input_size = len(subgraph[YAMLKeyword.input_tensors])
output_size = len(subgraph[YAMLKeyword.output_tensors])
mace_check(len(subgraph[YAMLKeyword.input_shapes]) == input_size,
ModuleName.YAML_CONFIG,
"input shapes' size not equal inputs' size.")
mace_check(len(subgraph[YAMLKeyword.output_shapes]) == output_size,
ModuleName.YAML_CONFIG,
"output shapes' size not equal outputs' size.")
for key in [YAMLKeyword.check_tensors, for key in [YAMLKeyword.check_tensors,
YAMLKeyword.check_shapes]: YAMLKeyword.check_shapes]:
...@@ -399,13 +404,13 @@ def format_model_config(flags): ...@@ -399,13 +404,13 @@ def format_model_config(flags):
if input_data_formats: if input_data_formats:
if not isinstance(input_data_formats, list): if not isinstance(input_data_formats, list):
subgraph[YAMLKeyword.input_data_formats] =\ subgraph[YAMLKeyword.input_data_formats] =\
[input_data_formats] [input_data_formats] * input_size
else: else:
mace_check(len(input_data_formats) mace_check(len(input_data_formats)
== len(subgraph[YAMLKeyword.input_tensors]), == input_size,
ModuleName.YAML_CONFIG, ModuleName.YAML_CONFIG,
"input_data_formats should match" "input_data_formats should match"
" the size of input") " the size of input.")
for input_data_format in\ for input_data_format in\
subgraph[YAMLKeyword.input_data_formats]: subgraph[YAMLKeyword.input_data_formats]:
mace_check(input_data_format in DataFormatStrs, mace_check(input_data_format in DataFormatStrs,
...@@ -414,17 +419,18 @@ def format_model_config(flags): ...@@ -414,17 +419,18 @@ def format_model_config(flags):
+ str(DataFormatStrs) + ", but got " + str(DataFormatStrs) + ", but got "
+ input_data_format) + input_data_format)
else: else:
subgraph[YAMLKeyword.input_data_formats] = [DataFormat.NHWC] subgraph[YAMLKeyword.input_data_formats] = \
[DataFormat.NHWC] * input_size
output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, output_data_formats = subgraph.get(YAMLKeyword.output_data_formats,
[]) [])
if output_data_formats: if output_data_formats:
if not isinstance(output_data_formats, list): if not isinstance(output_data_formats, list):
subgraph[YAMLKeyword.output_data_formats] = \ subgraph[YAMLKeyword.output_data_formats] = \
[output_data_formats] [output_data_formats] * output_size
else: else:
mace_check(len(output_data_formats) mace_check(len(output_data_formats)
== len(subgraph[YAMLKeyword.output_tensors]), == output_size,
ModuleName.YAML_CONFIG, ModuleName.YAML_CONFIG,
"output_data_formats should match" "output_data_formats should match"
" the size of output") " the size of output")
...@@ -435,7 +441,8 @@ def format_model_config(flags): ...@@ -435,7 +441,8 @@ def format_model_config(flags):
"'output_data_formats' must be in " "'output_data_formats' must be in "
+ str(DataFormatStrs)) + str(DataFormatStrs))
else: else:
subgraph[YAMLKeyword.output_data_formats] = [DataFormat.NHWC] subgraph[YAMLKeyword.output_data_formats] =\
[DataFormat.NHWC] * output_size
validation_threshold = subgraph.get( validation_threshold = subgraph.get(
YAMLKeyword.validation_threshold, {}) YAMLKeyword.validation_threshold, {})
......
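The YAML handling now broadcasts a scalar *_data_formats entry across every tensor and defaults to NHWC per tensor, after first checking that shape counts match tensor counts. An equivalent normalization sketch (illustrative):

def normalize_formats(formats, count, default="NHWC"):
    if not formats:
        return [default] * count
    if not isinstance(formats, list):
        return [formats] * count
    assert len(formats) == count, "data_formats should match the tensor count"
    return formats

assert normalize_formats(None, 2) == ["NHWC", "NHWC"]
assert normalize_formats("NCHW", 2) == ["NCHW", "NCHW"]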
...@@ -154,7 +154,9 @@ class DeviceWrapper: ...@@ -154,7 +154,9 @@ class DeviceWrapper:
input_nodes, input_nodes,
output_nodes, output_nodes,
input_shapes, input_shapes,
input_data_formats,
output_shapes, output_shapes,
output_data_formats,
mace_model_dir, mace_model_dir,
model_tag, model_tag,
device_type, device_type,
...@@ -216,6 +218,8 @@ class DeviceWrapper: ...@@ -216,6 +218,8 @@ class DeviceWrapper:
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
"--output_data_format=%s" % ",".join(output_data_formats),
"--input_file=%s/%s" % (model_output_dir, "--input_file=%s/%s" % (model_output_dir,
input_file_name), input_file_name),
"--output_file=%s/%s" % (model_output_dir, "--output_file=%s/%s" % (model_output_dir,
...@@ -307,6 +311,8 @@ class DeviceWrapper: ...@@ -307,6 +311,8 @@ class DeviceWrapper:
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
"--output_data_format=%s" % ",".join(output_data_formats),
"--input_file=%s/%s" % (self.data_dir, input_file_name), "--input_file=%s/%s" % (self.data_dir, input_file_name),
"--output_file=%s/%s" % (self.data_dir, output_file_name), "--output_file=%s/%s" % (self.data_dir, output_file_name),
"--input_dir=%s" % input_dir, "--input_dir=%s" % input_dir,
...@@ -394,6 +400,8 @@ class DeviceWrapper: ...@@ -394,6 +400,8 @@ class DeviceWrapper:
output_nodes=subgraphs[0][YAMLKeyword.output_tensors], output_nodes=subgraphs[0][YAMLKeyword.output_tensors],
input_shapes=subgraphs[0][YAMLKeyword.input_shapes], input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
output_shapes=subgraphs[0][YAMLKeyword.output_shapes], output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
input_data_formats=subgraphs[0][YAMLKeyword.input_data_formats],
output_data_formats=subgraphs[0][YAMLKeyword.output_data_formats],
mace_model_dir=mace_model_dir, mace_model_dir=mace_model_dir,
model_tag=model_name, model_tag=model_name,
device_type=DeviceType.GPU, device_type=DeviceType.GPU,
...@@ -587,6 +595,10 @@ class DeviceWrapper: ...@@ -587,6 +595,10 @@ class DeviceWrapper:
YAMLKeyword.output_tensors], YAMLKeyword.output_tensors],
input_shapes=subgraphs[0][YAMLKeyword.input_shapes], input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
output_shapes=output_config[YAMLKeyword.output_shapes], output_shapes=output_config[YAMLKeyword.output_shapes],
input_data_formats=subgraphs[0][
YAMLKeyword.input_data_formats],
output_data_formats=subgraphs[0][
YAMLKeyword.output_data_formats],
mace_model_dir=mace_model_dir, mace_model_dir=mace_model_dir,
model_tag=model_name, model_tag=model_name,
device_type=device_type, device_type=device_type,
...@@ -652,6 +664,10 @@ class DeviceWrapper: ...@@ -652,6 +664,10 @@ class DeviceWrapper:
YAMLKeyword.input_shapes], YAMLKeyword.input_shapes],
output_shapes=output_config[ output_shapes=output_config[
YAMLKeyword.output_shapes], YAMLKeyword.output_shapes],
input_data_formats=subgraphs[0][
YAMLKeyword.input_data_formats],
output_data_formats=subgraphs[0][
YAMLKeyword.output_data_formats],
model_output_dir=model_output_dir, model_output_dir=model_output_dir,
input_data_types=subgraphs[0][ input_data_types=subgraphs[0][
YAMLKeyword.input_data_types], YAMLKeyword.input_data_types],
...@@ -750,6 +766,8 @@ class DeviceWrapper: ...@@ -750,6 +766,8 @@ class DeviceWrapper:
output_nodes, output_nodes,
input_shapes, input_shapes,
output_shapes, output_shapes,
input_data_formats,
output_data_formats,
max_num_runs, max_num_runs,
max_seconds, max_seconds,
model_tag, model_tag,
...@@ -790,6 +808,8 @@ class DeviceWrapper: ...@@ -790,6 +808,8 @@ class DeviceWrapper:
'--output_node=%s' % ','.join(output_nodes), '--output_node=%s' % ','.join(output_nodes),
'--input_shape=%s' % ':'.join(input_shapes), '--input_shape=%s' % ':'.join(input_shapes),
'--output_shape=%s' % ':'.join(output_shapes), '--output_shape=%s' % ':'.join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
"--output_data_format=%s" % ",".join(output_data_formats),
'--input_file=%s/%s' % (model_output_dir, input_file_name), '--input_file=%s/%s' % (model_output_dir, input_file_name),
"--model_data_file=%s" % model_data_file, "--model_data_file=%s" % model_data_file,
'--max_num_runs=%d' % max_num_runs, '--max_num_runs=%d' % max_num_runs,
...@@ -845,6 +865,8 @@ class DeviceWrapper: ...@@ -845,6 +865,8 @@ class DeviceWrapper:
'--output_node=%s' % ','.join(output_nodes), '--output_node=%s' % ','.join(output_nodes),
'--input_shape=%s' % ':'.join(input_shapes), '--input_shape=%s' % ':'.join(input_shapes),
'--output_shape=%s' % ':'.join(output_shapes), '--output_shape=%s' % ':'.join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
"--output_data_format=%s" % ",".join(output_data_formats),
'--input_file=%s/%s' % (self.data_dir, input_file_name), '--input_file=%s/%s' % (self.data_dir, input_file_name),
"--model_data_file=%s" % model_data_file, "--model_data_file=%s" % model_data_file,
'--max_num_runs=%d' % max_num_runs, '--max_num_runs=%d' % max_num_runs,
...@@ -961,6 +983,10 @@ class DeviceWrapper: ...@@ -961,6 +983,10 @@ class DeviceWrapper:
output_nodes=output_nodes, output_nodes=output_nodes,
input_shapes=subgraphs[0][YAMLKeyword.input_shapes], input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
output_shapes=output_shapes, output_shapes=output_shapes,
input_data_formats=subgraphs[0][
YAMLKeyword.input_data_formats],
output_data_formats=subgraphs[0][
YAMLKeyword.output_data_formats],
max_num_runs=flags.max_num_runs, max_num_runs=flags.max_num_runs,
max_seconds=flags.max_seconds, max_seconds=flags.max_seconds,
mace_model_dir=mace_model_dir, mace_model_dir=mace_model_dir,
...@@ -974,8 +1000,7 @@ class DeviceWrapper: ...@@ -974,8 +1000,7 @@ class DeviceWrapper:
opencl_binary_file=opencl_output_bin_path, opencl_binary_file=opencl_output_bin_path,
opencl_parameter_file=opencl_parameter_path, opencl_parameter_file=opencl_parameter_path,
libmace_dynamic_library_path=LIBMACE_DYNAMIC_PATH, libmace_dynamic_library_path=LIBMACE_DYNAMIC_PATH,
link_dynamic=link_dynamic link_dynamic=link_dynamic)
)
def run(self, def run(self,
abi, abi,
......
...@@ -649,6 +649,8 @@ def validate_model(abi, ...@@ -649,6 +649,8 @@ def validate_model(abi,
output_nodes, output_nodes,
input_shapes, input_shapes,
output_shapes, output_shapes,
input_data_formats,
output_data_formats,
model_output_dir, model_output_dir,
input_data_types, input_data_types,
caffe_env, caffe_env,
...@@ -671,20 +673,12 @@ def validate_model(abi, ...@@ -671,20 +673,12 @@ def validate_model(abi,
sh.rm("-rf", "%s/%s" % (model_output_dir, formatted_name)) sh.rm("-rf", "%s/%s" % (model_output_dir, formatted_name))
device.pull_from_data_dir(formatted_name, model_output_dir) device.pull_from_data_dir(formatted_name, model_output_dir)
if platform == "tensorflow": if platform == "tensorflow" or platform == "onnx":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data,
log_file)
elif platform == "onnx":
validate(platform, model_file_path, "", validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name), "%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), device_type, "%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_data_formats), ",".join(output_data_formats),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend, validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data, validation_outputs_data,
...@@ -703,6 +697,8 @@ def validate_model(abi, ...@@ -703,6 +697,8 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), "%s/%s" % (model_output_dir, output_file_name),
device_type, device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_data_formats),
",".join(output_data_formats),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend, validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data, validation_outputs_data,
...@@ -770,6 +766,8 @@ def validate_model(abi, ...@@ -770,6 +766,8 @@ def validate_model(abi,
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
"--output_data_format=%s" % ",".join(output_data_formats),
"--validation_threshold=%f" % validation_threshold, "--validation_threshold=%f" % validation_threshold,
"--input_data_type=%s" % ",".join(input_data_types), "--input_data_type=%s" % ",".join(input_data_types),
"--backend=%s" % ",".join(backend), "--backend=%s" % ",".join(backend),
......
...@@ -148,10 +148,11 @@ def validate_with_file(platform, device_type, ...@@ -148,10 +148,11 @@ def validate_with_file(platform, device_type,
value, validation_threshold, log_file) value, validation_threshold, log_file)
def validate_tf_model(platform, device_type, model_file, input_file, def validate_tf_model(platform, device_type, model_file,
mace_out_file, input_names, input_shapes, input_file, mace_out_file,
output_names, validation_threshold, input_data_types, input_names, input_shapes, input_data_formats,
log_file): output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
common.MaceLogger.error( common.MaceLogger.error(
...@@ -174,6 +175,9 @@ def validate_tf_model(platform, device_type, model_file, input_file, ...@@ -174,6 +175,9 @@ def validate_tf_model(platform, device_type, model_file, input_file,
common.formatted_file_name(input_file, input_names[i]), common.formatted_file_name(input_file, input_names[i]),
input_data_types[i]) input_data_types[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
if input_data_formats[i] == common.DataFormat.NCHW and\
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 2, 3, 1))
input_node = graph.get_tensor_by_name( input_node = graph.get_tensor_by_name(
normalize_tf_tensor_name(input_names[i])) normalize_tf_tensor_name(input_names[i]))
input_dict[input_node] = input_value input_dict[input_node] = input_value
...@@ -188,15 +192,20 @@ def validate_tf_model(platform, device_type, model_file, input_file, ...@@ -188,15 +192,20 @@ def validate_tf_model(platform, device_type, model_file, input_file,
output_file_name = common.formatted_file_name( output_file_name = common.formatted_file_name(
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
if output_data_formats[i] == common.DataFormat.NCHW and\
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value.\
reshape(output_shapes[i]).transpose((0, 2, 3, 1))
compare_output(platform, device_type, output_names[i], compare_output(platform, device_type, output_names[i],
mace_out_value, output_values[i], mace_out_value, output_values[i],
validation_threshold, log_file) validation_threshold, log_file)
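For an NCHW-declared 4-D tensor the TensorFlow check now converts on both sides: the dumped input is transposed (0, 2, 3, 1) before being fed to the NHWC graph, and MACE's output is reshaped to its declared NCHW shape and transposed the same way before comparison. A numpy sketch of the input side (shapes illustrative):

import numpy as np

raw = np.arange(1 * 3 * 8 * 8, dtype=np.float32)     # flat dump, declared NCHW
fed = raw.reshape(1, 3, 8, 8).transpose(0, 2, 3, 1)  # -> NHWC for TensorFlow
assert fed.shape == (1, 8, 8, 3)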
def validate_caffe_model(platform, device_type, model_file, input_file, def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, input_shapes, mace_out_file, weight_file,
output_names, output_shapes, validation_threshold, input_names, input_shapes, input_data_formats,
log_file): output_names, output_shapes, output_data_formats,
validation_threshold, log_file):
os.environ['GLOG_minloglevel'] = '1' # suppress Caffe verbose prints os.environ['GLOG_minloglevel'] = '1' # suppress Caffe verbose prints
import caffe import caffe
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
...@@ -215,8 +224,10 @@ def validate_caffe_model(platform, device_type, model_file, input_file, ...@@ -215,8 +224,10 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data( input_value = load_data(
common.formatted_file_name(input_file, input_names[i])) common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, input_value = input_value.reshape(input_shapes[i])
2)) if input_data_formats[i] == common.DataFormat.NHWC and \
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 3, 1, 2))
input_blob_name = input_names[i] input_blob_name = input_names[i]
try: try:
if input_names[i] in net.top_names: if input_names[i] in net.top_names:
...@@ -232,22 +243,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file, ...@@ -232,22 +243,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
for i in range(len(output_names)): for i in range(len(output_names)):
value = net.blobs[output_names[i]].data value = net.blobs[output_names[i]].data
out_shape = output_shapes[i]
if len(out_shape) == 4:
out_shape[1], out_shape[2], out_shape[3] = \
out_shape[3], out_shape[1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = common.formatted_file_name( output_file_name = common.formatted_file_name(
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
if output_data_formats[i] == common.DataFormat.NHWC and \
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value.reshape(output_shapes[i])\
.transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], mace_out_value, compare_output(platform, device_type, output_names[i], mace_out_value,
value, validation_threshold, log_file) value, validation_threshold, log_file)
def validate_onnx_model(platform, device_type, model_file, input_file, def validate_onnx_model(platform, device_type, model_file,
mace_out_file, input_names, input_shapes, input_file, mace_out_file,
output_names, output_shapes, validation_threshold, input_names, input_shapes, input_data_formats,
input_data_types, backend, log_file): output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types,
backend, log_file):
import onnx import onnx
if backend == "tensorflow": if backend == "tensorflow":
from onnx_tf.backend import prepare from onnx_tf.backend import prepare
...@@ -269,13 +281,16 @@ def validate_onnx_model(platform, device_type, model_file, input_file, ...@@ -269,13 +281,16 @@ def validate_onnx_model(platform, device_type, model_file, input_file,
input_value = load_data(common.formatted_file_name(input_file, input_value = load_data(common.formatted_file_name(input_file,
input_names[i]), input_names[i]),
input_data_types[i]) input_data_types[i])
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, input_value = input_value.reshape(input_shapes[i])
2)) if input_data_formats[i] == common.DataFormat.NHWC and \
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 3, 1, 2))
input_dict[input_names[i]] = input_value input_dict[input_names[i]] = input_value
onnx_outputs = [] onnx_outputs = []
for i in range(len(output_names)): for i in range(len(output_names)):
out_shape = output_shapes[i] out_shape = output_shapes[i]
if len(out_shape) == 4: if output_data_formats[i] == common.DataFormat.NHWC and\
len(out_shape) == 4:
out_shape[1], out_shape[2], out_shape[3] = \ out_shape[1], out_shape[2], out_shape[3] = \
out_shape[3], out_shape[1], out_shape[2] out_shape[3], out_shape[1], out_shape[2]
onnx_outputs.append( onnx_outputs.append(
...@@ -289,25 +304,32 @@ def validate_onnx_model(platform, device_type, model_file, input_file, ...@@ -289,25 +304,32 @@ def validate_onnx_model(platform, device_type, model_file, input_file,
for i in range(len(output_names)): for i in range(len(output_names)):
out_name = output_names[i] out_name = output_names[i]
value = output_values[out_name].flatten() value = output_values[out_name].flatten()
out_shape = output_shapes[i]
if len(out_shape) == 4:
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = common.formatted_file_name(mace_out_file, output_file_name = common.formatted_file_name(mace_out_file,
output_names[i]) output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
if output_data_formats[i] == common.DataFormat.NHWC and \
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value.reshape(output_shapes[i]) \
.transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], compare_output(platform, device_type, output_names[i],
mace_out_value, value, mace_out_value, value,
validation_threshold, log_file) validation_threshold, log_file)
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node, device_type, input_shape, output_shape, input_data_format_str,
output_data_format_str, input_node, output_node,
validation_threshold, input_data_type, backend, validation_threshold, input_data_type, backend,
validation_outputs_data, log_file): validation_outputs_data, log_file):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
for shape in input_shape_strs] for shape in input_shape_strs]
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs]
input_data_formats = [df for df in input_data_format_str.split(',')]
output_data_formats = [df for df in output_data_format_str.split(',')]
if input_data_type: if input_data_type:
input_data_types = [data_type input_data_types = [data_type
for data_type in input_data_type.split(',')] for data_type in input_data_type.split(',')]
...@@ -323,32 +345,27 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file, ...@@ -323,32 +345,27 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
else: else:
validation_outputs = validation_outputs_data validation_outputs = validation_outputs_data
if validation_outputs: if validation_outputs:
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs]
validate_with_file(platform, device_type, output_names, output_shapes, validate_with_file(platform, device_type, output_names, output_shapes,
mace_out_file, validation_outputs, mace_out_file, validation_outputs,
validation_threshold, log_file) validation_threshold, log_file)
elif platform == 'tensorflow': elif platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file, validate_tf_model(platform, device_type,
mace_out_file, input_names, input_shapes, model_file, input_file, mace_out_file,
output_names, validation_threshold, input_data_types, input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types,
log_file) log_file)
elif platform == 'caffe': elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')] validate_caffe_model(platform, device_type, model_file,
output_shapes = [[int(x) for x in shape.split(',')] input_file, mace_out_file, weight_file,
for shape in output_shape_strs] input_names, input_shapes, input_data_formats,
validate_caffe_model(platform, device_type, model_file, input_file, output_names, output_shapes, output_data_formats,
mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes,
validation_threshold, log_file) validation_threshold, log_file)
elif platform == 'onnx': elif platform == 'onnx':
output_shape_strs = [shape for shape in output_shape.split(':')] validate_onnx_model(platform, device_type, model_file,
output_shapes = [[int(x) for x in shape.split(',')] input_file, mace_out_file,
for shape in output_shape_strs] input_names, input_shapes, input_data_formats,
validate_onnx_model(platform, device_type, model_file, input_file, output_names, output_shapes, output_data_formats,
mace_out_file, input_names, input_shapes,
output_names, output_shapes,
validation_threshold, validation_threshold,
input_data_types, backend, log_file) input_data_types, backend, log_file)
...@@ -379,8 +396,14 @@ def parse_args(): ...@@ -379,8 +396,14 @@ def parse_args():
"--device_type", type=str, default="", help="mace runtime device.") "--device_type", type=str, default="", help="mace runtime device.")
parser.add_argument( parser.add_argument(
"--input_shape", type=str, default="1,64,64,3", help="input shape.") "--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument(
"--input_data_format", type=str, default="NHWC",
help="input data format.")
parser.add_argument( parser.add_argument(
"--output_shape", type=str, default="1,64,64,2", help="output shape.") "--output_shape", type=str, default="1,64,64,2", help="output shape.")
parser.add_argument(
"--output_data_format", type=str, default="NHWC",
help="output data format.")
parser.add_argument( parser.add_argument(
"--input_node", type=str, default="input_node", help="input node") "--input_node", type=str, default="input_node", help="input node")
parser.add_argument( parser.add_argument(
...@@ -417,6 +440,8 @@ if __name__ == '__main__': ...@@ -417,6 +440,8 @@ if __name__ == '__main__':
FLAGS.device_type, FLAGS.device_type,
FLAGS.input_shape, FLAGS.input_shape,
FLAGS.output_shape, FLAGS.output_shape,
FLAGS.input_data_format,
FLAGS.output_data_format,
FLAGS.input_node, FLAGS.input_node,
FLAGS.output_node, FLAGS.output_node,
FLAGS.validation_threshold, FLAGS.validation_threshold,
......