Commit c282cb3e authored by liuqi

Feature: DepthwiseConv2D supports a non-const filter tensor.

1. The input_data_formats flag in the YAML deployment file now accepts OIHW.
2. MaceTensor now accepts the OIHW data format.

Parent commit: 3e795fa3
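Taken together, the two points mean a filter-shaped tensor can be declared as a model input and supplied at run time instead of being baked in as a constant weight. A minimal usage sketch under that assumption; the tensor names and shapes below are hypothetical, while the `MaceTensor` format check and `DataFormat::OIHW` come from this commit and `MaceEngine::Run` is the existing public API:

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

#include "mace/public/mace.h"

// Hypothetical tensor names and shapes; only DataFormat::OIHW is new here.
mace::MaceStatus RunWithRuntimeFilter(mace::MaceEngine *engine) {
  auto alloc = [](int64_t size) {
    return std::shared_ptr<float>(new float[size],
                                  std::default_delete<float[]>());
  };
  std::map<std::string, mace::MaceTensor> inputs;
  std::map<std::string, mace::MaceTensor> outputs;
  inputs.emplace("input", mace::MaceTensor(
      {1, 224, 224, 3}, alloc(1 * 224 * 224 * 3), mace::DataFormat::NHWC));
  // The depthwise filter is an ordinary input now, not a constant weight.
  inputs.emplace("dw_filter", mace::MaceTensor(
      {1, 3, 3, 3}, alloc(1 * 3 * 3 * 3), mace::DataFormat::OIHW));
  outputs.emplace("output", mace::MaceTensor(
      {1, 224, 224, 3}, alloc(1 * 224 * 224 * 3), mace::DataFormat::NHWC));
  return engine->Run(inputs, &outputs);
}
```

At graph-construction time the runtime then inserts a BufferTransform in front of the filter input (see the SerialNet and DepthwiseConv2d changes below) instead of converting a constant weight once.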
@@ -95,6 +95,8 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
     return DataFormat::NHWC;
   } else if (data_format_str == "NCHW") {
     return DataFormat::NCHW;
+  } else if (data_format_str == "OIHW") {
+    return DataFormat::OIHW;
   } else {
     return DataFormat::DF_NONE;
   }
...
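This parser change (repeated in two more tools below) backs point 1 of the commit message. In a deployment file, a runtime filter input would be declared roughly as follows; this is a sketch with placeholder model and tensor names, since the commit itself includes no sample config:

```yaml
# Sketch of a deployment entry; model/tensor names, shapes and the
# platform value are placeholders. Only the OIHW format value is new.
models:
  some_model:
    platform: tensorflow
    subgraphs:
      - input_tensors:
          - input
          - dw_filter          # filter supplied at run time
        input_shapes:
          - 1,224,224,3
          - 1,3,3,3            # OIHW
        input_data_formats:
          - NHWC
          - OIHW               # newly accepted by ParseDataFormat
```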
@@ -61,21 +61,29 @@ void MemoryOptimizer::UpdateTensorRef(const mace::OperatorDef *op_def) {
 }
 
 MemoryBlock MemoryOptimizer::CreateMemoryBlock(
-    std::vector<int64_t> shape,
+    const OperatorDef *op_def,
+    int output_idx,
     DataType dt,
-    mace::MemoryType mem_type) {
+    MemoryType mem_type) {
+  auto shape = std::vector<int64_t>(
+      op_def->output_shape(output_idx).dims().begin(),
+      op_def->output_shape(output_idx).dims().end());
   MemoryBlock block;
 #ifdef MACE_ENABLE_OPENCL
   if (mem_type == MemoryType::GPU_IMAGE) {
+    OpenCLBufferType buffer_type = OpenCLBufferType::IN_OUT_CHANNEL;
+    if (op_def->type() == "BufferTransform") {
+      buffer_type = static_cast<OpenCLBufferType>(
+          ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+              *op_def, "buffer_type", OpenCLBufferType::IN_OUT_CHANNEL));
+    }
     std::vector<size_t> image_shape;
     if (shape.size() == 2) {
       shape = {shape[0], 1, 1, shape[1]};
     } else {
       MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input";
     }
-    OpenCLUtil::CalImage2DShape(shape,
-                                OpenCLBufferType::IN_OUT_CHANNEL,
-                                &image_shape);
+    OpenCLUtil::CalImage2DShape(shape, buffer_type, &image_shape);
     block.set_x(image_shape[0]);
     block.set_y(image_shape[1]);
     return block;
@@ -93,7 +101,7 @@ MemoryBlock MemoryOptimizer::CreateMemoryBlock(
 void MemoryOptimizer::Optimize(
     const mace::OperatorDef *op_def,
-    const std::unordered_map<std::string, MemoryType> &mem_types) {
+    const std::unordered_map<std::string, MemoryType> *mem_types) {
   MACE_LATENCY_LOGGER(2, "Optimize memory");
   if (op_def->output_size() != op_def->output_shape_size()) {
     VLOG(1) << op_def->name()
@@ -127,22 +135,15 @@ void MemoryOptimizer::Optimize(
     int best_mem_id = -1;
     MemoryType mem_type = MemoryType::CPU_BUFFER;
     if (device == DeviceType::GPU) {
-      mem_type = mem_types.at(op_def->output(i));
+      mem_type = mem_types->at(op_def->output(i));
     }
-    auto shape = std::vector<int64_t>(
-        op_def->output_shape(i).dims().begin(),
-        op_def->output_shape(i).dims().end());
-    MemoryBlock op_mem_block = CreateMemoryBlock(shape, dt, mem_type);
+    MemoryBlock op_mem_block = CreateMemoryBlock(op_def, i, dt, mem_type);
     MemoryBlock best_mem_block;
     if (IsMemoryReuseOp(op_def->type())) {
       if (tensor_mem_map_.count(op_def->input(0)) == 1) {
         best_mem_id = tensor_mem_map_.at(op_def->input(0)).mem_id;
       }
     } else {
-      auto shape = std::vector<int64_t>(
-          op_def->output_shape(i).dims().begin(),
-          op_def->output_shape(i).dims().end());
       int64_t op_mem_size = op_mem_block.x() * op_mem_block.y();
       int64_t best_added_mem_size = LLONG_MAX;
       int64_t best_wasted_mem_size = LLONG_MAX;
...
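The new overload works because the buffer type travels with the op itself: CreateTransformOpDef (changed further down) stores it as an integer "buffer_type" argument on the BufferTransform op, and CreateMemoryBlock reads it back. A small self-contained sketch of that round trip, using only calls that appear in this diff (the arg_helper include path is an assumption):

```cpp
#include "mace/core/arg_helper.h"                  // ProtoArgHelper (assumed path)
#include "mace/core/runtime/opencl/opencl_util.h"  // OpenCLBufferType
#include "mace/proto/mace.pb.h"                    // OperatorDef, Argument

namespace mace {

OpenCLBufferType ReadBackBufferType() {
  // Attach the argument, as OpenCLUtil::CreateTransformOpDef now does.
  OperatorDef op_def;
  op_def.set_type("BufferTransform");
  Argument *arg = op_def.add_arg();
  arg->set_name("buffer_type");
  arg->set_i(static_cast<int32_t>(OpenCLBufferType::DW_CONV2D_FILTER));

  // Read it back, as the new CreateMemoryBlock overload does;
  // IN_OUT_CHANNEL is the fallback for ops without the argument.
  return static_cast<OpenCLBufferType>(
      ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          op_def, "buffer_type", OpenCLBufferType::IN_OUT_CHANNEL));
}

}  // namespace mace
```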
@@ -92,8 +92,9 @@ class MemoryOptimizer {
   static bool IsMemoryReuseOp(const std::string &op_type);
   void UpdateTensorRef(const std::string &tensor_name);
   void UpdateTensorRef(const OperatorDef *op_def);
-  void Optimize(const OperatorDef *op_def,
-                const std::unordered_map<std::string, MemoryType> &mem_types);
+  void Optimize(
+      const OperatorDef *op_def,
+      const std::unordered_map<std::string, MemoryType> *mem_types = nullptr);
 
   const std::vector<MemoryBlock> &mem_blocks() const;
@@ -102,9 +103,11 @@ class MemoryOptimizer {
   std::string DebugInfo() const;
 
  private:
-  MemoryBlock CreateMemoryBlock(std::vector<int64_t> shape,
-                                DataType dt,
-                                MemoryType mem_type);
+  MemoryBlock CreateMemoryBlock(
+      const OperatorDef *op_def,
+      int output_idx,
+      DataType dt,
+      MemoryType mem_type);
 
  private:
   std::unordered_map<std::string, int> tensor_ref_count_;
...
@@ -38,12 +38,15 @@ namespace {
 struct InternalOutputInfo {
   InternalOutputInfo(const MemoryType mem_type,
                      const DataType dtype,
+                     const DataFormat data_format,
                      const std::vector<index_t> &shape,
                      int op_idx)
-      : mem_type(mem_type), dtype(dtype), shape(shape), op_idx(op_idx) {}
+      : mem_type(mem_type), dtype(dtype), data_format(data_format),
+        shape(shape), op_idx(op_idx) {}
 
   MemoryType mem_type;  // transformed memory type
   DataType dtype;
+  DataFormat data_format;
   std::vector<index_t> shape;  // tensor shape
   int op_idx;  // operation which generate the tensor
 };
@@ -132,13 +135,6 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
           target_device->cpu_runtime()->policy(),
           target_device->cpu_runtime()->use_gemmlowp())) {
   MACE_LATENCY_LOGGER(1, "Constructing SerialNet");
-  // output tensor : related information
-  std::unordered_map<std::string, InternalOutputInfo> output_map;
-  // used for memory optimization
-  std::unordered_map<std::string, MemoryType> output_mem_map;
-  std::unordered_set<std::string> transformed_set;
-  // add input information
-  MemoryType target_mem_type;
   // quantize model flag
   bool is_quantize_model = IsQuantizedModel(*net_def);
   // Tensor Shape map
@@ -161,7 +157,6 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
   bool has_data_format = false;
   if (target_device_->device_type() == DeviceType::CPU) {
-    target_mem_type = MemoryType::CPU_BUFFER;
     for (auto &input_info : net_def->input_info()) {
       std::vector<index_t> input_shape =
           std::vector<index_t>(input_info.dims().begin(),
@@ -178,26 +173,37 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
         // NHWC -> NCHW
         input_shape =
            TransposeShape<index_t, index_t>(input_shape, {0, 3, 1, 2});
+        input_data_format = DataFormat::NCHW;
       }
-      output_map.emplace(input_info.name(), InternalOutputInfo(
-          target_mem_type, DataType::DT_FLOAT, input_shape, -1));
     }
   }
 #ifdef MACE_ENABLE_OPENCL
-  else {  // GPU  NOLINT[readability/braces]
+  // output tensor : related information
+  std::unordered_map<std::string, InternalOutputInfo> output_map;
+  // used for memory optimization
+  std::unordered_map<std::string, MemoryType> output_mem_map;
+  std::unordered_set<std::string> transformed_set;
+  // add input information
+  MemoryType target_mem_type;
+  // default data format of output tensor
+  DataFormat default_output_df = DataFormat::DF_NONE;
+  if (target_device_->device_type() == DeviceType::GPU) {
     target_mem_type = MemoryType::GPU_BUFFER;
     for (auto &input_info : net_def->input_info()) {
-      has_data_format = static_cast<DataFormat>(
-          input_info.data_format()) == NHWC;
+      DataFormat input_data_format = static_cast<DataFormat>(
+          input_info.data_format());
+      has_data_format = input_data_format != DataFormat::DF_NONE;
       std::vector<index_t> input_shape =
           std::vector<index_t>(input_info.dims().begin(),
                                input_info.dims().end());
       // update tensor shape map
       tensor_shape_map[input_info.name()] = input_shape;
       output_map.emplace(input_info.name(), InternalOutputInfo(
-          target_mem_type, DataType::DT_FLOAT, input_shape, -1));
+          target_mem_type, DataType::DT_FLOAT, input_data_format,
+          input_shape, -1));
     }
+    default_output_df =
+        has_data_format ? DataFormat::NHWC : DataFormat::DF_NONE;
   }
 #endif  // MACE_ENABLE_OPENCL
@@ -242,11 +248,13 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
                 << output_info.mem_type << " to "
                 << wanted_in_mem_type
                 << ", from Data Type " << output_info.dtype << " to "
-                << wanted_in_dt;
+                << wanted_in_dt << ". with data format "
+                << output_info.data_format;
         std::string input_name = op_def->input(i);
         op_def->set_input(i, t_input_name);
         auto input_shape = output_info.shape;
         if (output_info.mem_type == MemoryType::CPU_BUFFER &&
+            output_info.data_format == DataFormat::NCHW &&
             input_shape.size() == 4) {
           // NCHW -> NHWC
           input_shape =
@@ -254,8 +262,9 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
                                                 {0, 2, 3, 1});
         }
         auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
-            input_name, input_shape, t_input_name,
-            wanted_in_dt, wanted_in_mem_type, has_data_format);
+            input_name, input_shape, t_input_name, wanted_in_dt,
+            construct_context.GetInputOpenCLBufferType(i),
+            wanted_in_mem_type, has_data_format);
         OpConstructContext t_construct_context(ws_);
         auto transform_op = CreateOperation(
             op_registry,
@@ -295,6 +304,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
               InternalOutputInfo(
                   out_mem_type,
                   dt,
+                  default_output_df,
                   op_def->output_shape().empty() ?
                   std::vector<index_t>() :
                   std::vector<index_t>(
@@ -343,6 +353,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
             internal_output_info.shape,
             output_info.name(),
             output_info.data_type(),
+            OpenCLBufferType::IN_OUT_CHANNEL,
             target_mem_type,
             output_has_data_format);
         auto transform_op = CreateOperation(
@@ -366,7 +377,11 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
   for (auto &op : operators_) {
     VLOG(2) << "Operator " << op->debug_def().name() << "<" << op->device_type()
             << ", " << op->debug_def().type() << ">";
-    mem_optimizer->Optimize(op->operator_def().get(), output_mem_map);
+#ifdef MACE_ENABLE_OPENCL
+    mem_optimizer->Optimize(op->operator_def().get(), &output_mem_map);
+#else
+    mem_optimizer->Optimize(op->operator_def().get());
+#endif  // MACE_ENABLE_OPENCL
   }
   VLOG(1) << mem_optimizer->DebugInfo();
 }
@@ -448,7 +463,7 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
       bool transpose_a = op->GetOptionalArg<bool>("transpose_a", false);
       kernels = op->Input(0)->shape();
       if (transpose_a) {
-        std::swap(kernels[kernels.size()-2], kernels[kernels.size()-1]);
+        std::swap(kernels[kernels.size() - 2], kernels[kernels.size() - 1]);
       }
     } else if (type.compare("FullyConnected") == 0) {
       kernels = op->Input(1)->shape();
@@ -494,16 +509,16 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
         Tensor::MappingGuard guard(op->Output(i));
         auto *output_data = op->Output(i)->data<float>();
         for (index_t j = 0; j < op->Output(i)->size(); ++j) {
           int index = static_cast<int>((output_data[j] - min_v) / bin_v);
           if (index < 0)
             index = 0;
-          else if (index > bin_size-1)
-            index = bin_size-1;
+          else if (index > bin_size - 1)
+            index = bin_size - 1;
           bin_distribution[index]++;
         }
         LOG(INFO) << "Tensor range @@" << op->debug_def().output(i)
-                  << "@@" << min_v << "," << max_v<< "@@"
+                  << "@@" << min_v << "," << max_v << "@@"
                   << MakeString(bin_distribution);
       }
     }
   }
...
@@ -86,6 +86,27 @@ DataType OpConstructContext::GetInputDataType(size_t idx) const {
   return input_data_types_[idx];
 }
 
+#ifdef MACE_ENABLE_OPENCL
+void OpConstructContext::SetInputOpenCLBufferType(
+    size_t idx, OpenCLBufferType buffer_type) {
+  if (input_opencl_buffer_types_.empty()) {
+    // the default inputs' memory types are same as output memory type.
+    input_opencl_buffer_types_.resize(operator_def_->input_size(),
+                                      OpenCLBufferType::IN_OUT_CHANNEL);
+  }
+  MACE_CHECK(idx < input_opencl_buffer_types_.size());
+  input_opencl_buffer_types_[idx] = buffer_type;
+}
+
+OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType(
+    size_t idx) const {
+  if (input_opencl_buffer_types_.empty()) {
+    return OpenCLBufferType::IN_OUT_CHANNEL;
+  }
+  MACE_CHECK(idx < input_opencl_buffer_types_.size());
+  return input_opencl_buffer_types_[idx];
+}
+#endif  // MACE_ENABLE_OPENCL
+
 OpInitContext::OpInitContext(Workspace *ws, Device *device)
     : ws_(ws), device_(device) {}
...
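For orientation, here are the two call sites of these accessors elsewhere in this commit, shown together (excerpts, not a self-contained program): an op whose filter arrives at run time records the layout it expects during construction, and SerialNet reads it back when building the BufferTransform op for that input.

```cpp
// In an op's construct stage, when input 1 is not a constant weight
// (excerpted from the DepthwiseConv2d change below):
context->SetInputOpenCLBufferType(1, OpenCLBufferType::DW_CONV2D_FILTER);

// In SerialNet, the recorded type is forwarded into the BufferTransform
// op created for input i (excerpted from the net-construction change above):
auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
    input_name, input_shape, t_input_name, wanted_in_dt,
    construct_context.GetInputOpenCLBufferType(i),
    wanted_in_mem_type, has_data_format);
```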
@@ -26,6 +26,9 @@
 #include "mace/core/tensor.h"
 #include "mace/core/workspace.h"
 #include "mace/proto/mace.pb.h"
+#ifdef MACE_ENABLE_OPENCL
+#include "mace/core/runtime/opencl/opencl_util.h"
+#endif  // MACE_ENABLE_OPENCL
 
 namespace mace {
@@ -72,6 +75,11 @@ class OpConstructContext {
   DataType GetInputDataType(size_t idx) const;
 
+#ifdef MACE_ENABLE_OPENCL
+  void SetInputOpenCLBufferType(size_t idx, OpenCLBufferType buffer_type);
+  OpenCLBufferType GetInputOpenCLBufferType(size_t idx) const;
+#endif  // MACE_ENABLE_OPENCL
+
  private:
   std::shared_ptr<OperatorDef> operator_def_;
   Workspace *ws_;
@@ -81,6 +89,9 @@ class OpConstructContext {
   std::vector<MemoryType> input_mem_types_;
   std::vector<DataType> input_data_types_;
   MemoryType output_mem_type_;  // there is only one output memory type now.
+#ifdef MACE_ENABLE_OPENCL
+  std::vector<OpenCLBufferType> input_opencl_buffer_types_;
+#endif  // MACE_ENABLE_OPENCL
 };
 
 // memory_optimizer, device
...
@@ -151,6 +151,7 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
     const std::vector<mace::index_t> &input_shape,
     const std::string &output_name,
     const mace::DataType dt,
+    const OpenCLBufferType buffer_type,
    const mace::MemoryType mem_type,
     bool has_data_format) {
   std::unique_ptr<OperatorDef> op(new OperatorDef);
@@ -161,7 +162,7 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
   op->add_output(output_name);
   Argument *arg = op->add_arg();
   arg->set_name("buffer_type");
-  arg->set_i(static_cast<int32_t>(OpenCLBufferType::IN_OUT_CHANNEL));
+  arg->set_i(static_cast<int32_t>(buffer_type));
   arg = op->add_arg();
   arg->set_name("mem_type");
   arg->set_i(static_cast<int32_t>(mem_type));
...
@@ -48,6 +48,7 @@ class OpenCLUtil {
       const std::vector<mace::index_t> &input_shape,
       const std::string &output_name,
       const mace::DataType dt,
+      const OpenCLBufferType buffer_type,
       const MemoryType mem_type,
       bool has_data_format);
 };
...
@@ -97,8 +97,6 @@ inline std::ostream &operator<<(std::ostream &os, unsigned char c) {
 }
 }  // namespace numerical_chars
 
-enum FilterDataFormat { HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103 };
-
 class Tensor {
  public:
   Tensor(Allocator *alloc, DataType type,
...
@@ -68,7 +68,7 @@ const Tensor *Workspace::GetTensor(const std::string &name) const {
   if (tensor_map_.count(name)) {
     return tensor_map_.at(name).get();
   } else {
-    LOG(WARNING) << "Tensor " << name << " does not exist.";
+    VLOG(1) << "Tensor " << name << " does not exist.";
   }
   return nullptr;
 }
...
@@ -175,6 +175,8 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
     return DataFormat::NHWC;
   } else if (data_format_str == "NCHW") {
     return DataFormat::NCHW;
+  } else if (data_format_str == "OIHW") {
+    return DataFormat::OIHW;
   } else {
     return DataFormat::DF_NONE;
   }
...
@@ -291,6 +291,9 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape,
                        std::shared_ptr<float> data,
                        const DataFormat format) {
   MACE_CHECK_NOTNULL(data.get());
+  MACE_CHECK(format == DataFormat::NHWC || format == DataFormat::NCHW
+                 || format == OIHW,
+             "MACE only support NHWC, NCHW and OIHW formats of input now.");
   impl_ = make_unique<MaceTensor::Impl>();
   impl_->shape = shape;
   impl_->data = data;
...
@@ -24,7 +24,7 @@ namespace ops {
 void CalcPaddingAndOutputSize(const index_t *input_shape,
                               const DataFormat input_format,
                               const index_t *filter_shape,
-                              const FilterDataFormat filter_format,
+                              const DataFormat filter_format,
                               const int *dilations,
                               const int *strides,
                               Padding padding,
@@ -137,7 +137,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,  // NHWC
 void CalcOutputSize(const index_t *input_shape,
                     const DataFormat input_format,
                     const index_t *filter_shape,
-                    const FilterDataFormat filter_format,
+                    const DataFormat filter_format,
                     const int *padding_size,
                     const int *dilations,
                     const int *strides,
...
@@ -35,7 +35,7 @@ namespace ops {
 void CalcPaddingAndOutputSize(const index_t *input_shape,
                               const DataFormat input_format,
                               const index_t *filter_shape,
-                              const FilterDataFormat filter_format,
+                              const DataFormat filter_format,
                               const int *dilations,
                               const int *strides,
                               Padding padding,
@@ -61,7 +61,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,
 void CalcOutputSize(const index_t *input_shape,
                     const DataFormat input_format,
                     const index_t *filter_shape,
-                    const FilterDataFormat filter_format,
+                    const DataFormat filter_format,
                     const int *padding_size,
                     const int *dilations,
                     const int *strides,
...
@@ -500,13 +500,19 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
       kernel_ = make_unique<opencl::buffer::DepthwiseConv2dKernel<T>>();
     }
     context->set_output_mem_type(mem_type);
-    // Transform filter tensor to target format
-    MACE_CHECK(TransformFilter<T>(
-        context,
-        operator_def_.get(),
-        1,
-        OpenCLBufferType::DW_CONV2D_FILTER,
-        mem_type) == MaceStatus::MACE_SUCCESS);
+    Tensor *filter_tensor = context->workspace()->GetTensor(
+        operator_def_->input(1));
+    if (filter_tensor != nullptr && filter_tensor->is_weight()) {
+      // Transform filter tensor to target format
+      MACE_CHECK(TransformFilter<T>(
+          context,
+          operator_def_.get(),
+          1,
+          OpenCLBufferType::DW_CONV2D_FILTER,
+          mem_type) == MaceStatus::MACE_SUCCESS);
+    } else {
+      context->SetInputOpenCLBufferType(1, OpenCLBufferType::DW_CONV2D_FILTER);
+    }
     if (operator_def_->input_size() > 2) {
       MACE_CHECK(TransformFilter<T>(
           context, operator_def_.get(), 2, OpenCLBufferType::ARGUMENT, mem_type)
...
@@ -259,9 +259,9 @@ class OpsTestNet {
   template <DeviceType D, typename T>
   void TransformFilterDataFormat(const std::string &src_name,
-                                 const FilterDataFormat src_format,
+                                 const DataFormat src_format,
                                  const std::string &dst_name,
-                                 const FilterDataFormat dst_format) {
+                                 const DataFormat dst_format) {
     Tensor *input = ws_.GetTensor(src_name);
     Tensor *output = ws_.CreateTensor(
         dst_name,
...
@@ -34,7 +34,10 @@ class NetDef;
 enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3 };
 
-enum DataFormat { DF_NONE = 0, NHWC = 1, NCHW = 2};
+enum DataFormat {
+  DF_NONE = 0, NHWC = 1, NCHW = 2,
+  HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103
+};
 
 enum GPUPerfHint {
   PERF_DEFAULT = 0,
...
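Since converted models serialize the filter format as a bare integer argument, the merged enum keeps the old FilterDataFormat numbering (100 and up) rather than renumbering the filter layouts. A compile-time sanity check of the kind one could write against this public header (not part of the commit):

```cpp
#include "mace/public/mace.h"

// Not in the commit: asserts that the merged DataFormat enum preserves the
// integer values that serialized models already carry.
static_assert(mace::DataFormat::NCHW == 2,
              "base layouts keep their values");
static_assert(mace::DataFormat::OIHW == 101,
              "filter layouts keep the old FilterDataFormat values");
```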
@@ -43,6 +43,7 @@ data_format_map = {
     'NONE': cvt.DataFormat.DF_NONE,
     'NHWC': cvt.DataFormat.NHWC,
     'NCHW': cvt.DataFormat.NCHW,
+    'OIHW': cvt.DataFormat.OIHW,
 }
...
@@ -28,9 +28,6 @@ class DataFormat(Enum):
     DF_NONE = 0
     NHWC = 1
     NCHW = 2
-
-
-class FilterFormat(Enum):
     HWIO = 100
     OIHW = 101
     HWOI = 102
@@ -571,11 +568,11 @@ class ConverterUtil(object):
         arg = ConverterUtil.get_arg(net, MaceKeyword.mace_filter_format_str)
         if arg is None:
             return None
-        elif arg.i == FilterFormat.HWIO.value:
-            return FilterFormat.HWIO
-        elif arg.i == FilterFormat.HWOI.value:
-            return FilterFormat.HWOI
-        elif arg.i == FilterFormat.OIHW.value:
-            return FilterFormat.OIHW
+        elif arg.i == DataFormat.HWIO.value:
+            return DataFormat.HWIO
+        elif arg.i == DataFormat.HWOI.value:
+            return DataFormat.HWOI
+        elif arg.i == DataFormat.OIHW.value:
+            return DataFormat.OIHW
         else:
             return None
...
@@ -27,7 +27,6 @@ from mace.python.tools.converter_tool.base_converter import ActivationType
 from mace.python.tools.converter_tool.base_converter import EltwiseType
 from mace.python.tools.converter_tool.base_converter import FrameworkType
 from mace.python.tools.converter_tool.base_converter import DataFormat
-from mace.python.tools.converter_tool.base_converter import FilterFormat
 from mace.python.tools.converter_tool.base_converter import MaceOp
 from mace.python.tools.converter_tool.base_converter import MaceKeyword
 from mace.python.tools.converter_tool.base_converter import ConverterUtil
@@ -194,7 +193,7 @@ class CaffeConverter(base_converter.ConverterInterface):
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
-        ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW)
+        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
         self._caffe_net = CaffeNet()
         self._caffe_layers = caffe_pb2.NetParameter()
         caffe_weights = caffe_pb2.NetParameter()
...
@@ -27,7 +27,6 @@ from mace.python.tools.converter_tool.base_converter import ReduceType
 from mace.python.tools.converter_tool.base_converter import FrameworkType
 from mace.python.tools.converter_tool.base_converter import RoundMode
 from mace.python.tools.converter_tool.base_converter import DataFormat
-from mace.python.tools.converter_tool.base_converter import FilterFormat
 from mace.python.tools.converter_tool.base_converter import MaceOp
 from mace.python.tools.converter_tool.base_converter import MaceKeyword
 from mace.python.tools.converter_tool.base_converter import ConverterUtil
@@ -370,7 +369,7 @@ class OnnxConverter(base_converter.ConverterInterface):
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
         self._data_format = DataFormat.NCHW
-        ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW)
+        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
         onnx_model = onnx.load(src_model_file)
         ir_version = onnx_model.ir_version
...
@@ -20,7 +20,6 @@ import six
 
 from mace.python.tools.converter_tool.transformer import Transformer
 from mace.python.tools.converter_tool.base_converter import DataFormat
-from mace.python.tools.converter_tool.base_converter import FilterFormat
 from mace.python.tools.converter_tool.base_converter import MaceOp
 from mace.python.tools.converter_tool.base_converter import MaceKeyword
 from mace.python.tools.converter_tool.base_converter import ConverterUtil
@@ -129,7 +128,7 @@ class ShapeInference(object):
         output_shape[0] = input_shape[0]
         if ConverterUtil.data_format(op) == DataFormat.NCHW \
-                and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW:  # noqa
+                and ConverterUtil.filter_format(self._net) == DataFormat.OIHW:  # noqa
             # filter format: OIHW
             if op.type == MaceOp.DepthwiseConv2d.name:
                 output_shape[1] = filter_shape[0] * filter_shape[1]
@@ -170,7 +169,7 @@ class ShapeInference(object):
                                              MaceKeyword.mace_group_str)
         output_shape[0] = input_shape[0]
         if ConverterUtil.data_format(op) == DataFormat.NCHW \
-                and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW:  # noqa
+                and ConverterUtil.filter_format(self._net) == DataFormat.OIHW:  # noqa
             # filter format: IOHW
             output_shape[1] = filter_shape[1]
             if group_arg is not None and group_arg.i > 1:
...
@@ -29,7 +29,6 @@ from mace.python.tools.converter_tool.base_converter import PadType
 from mace.python.tools.converter_tool.base_converter import FrameworkType
 from mace.python.tools.converter_tool.base_converter import ReduceType
 from mace.python.tools.converter_tool.base_converter import DataFormat
-from mace.python.tools.converter_tool.base_converter import FilterFormat
 from mace.python.tools.converter_tool.base_converter import MaceOp
 from mace.python.tools.converter_tool.base_converter import MaceKeyword
 from mace.python.tools.converter_tool.base_converter import ConverterUtil
@@ -280,7 +279,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
-        ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO)
+        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
 
         # import tensorflow graph
         tf_graph_def = tf.GraphDef()
@@ -347,13 +346,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
         for input_node in self._option.input_nodes.values():
             if node.name == input_node.name \
                     or node.name + ':0' == input_node.name:
+                input_shape = input_node.shape
+                if input_node.data_format == DataFormat.OIHW \
+                        and len(input_shape) == 4:
+                    # OIHW -> HWIO
+                    input_shape = [input_shape[2], input_shape[3],
+                                   input_shape[1], input_shape[0]]
                 del node.attr['shape'].shape.dim[:]
                 node.attr['shape'].shape.dim.extend([
                     tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
-                    input_node.shape
+                    input_shape
                 ])
                 self._placeholders[node.name + ':0'] = \
-                    np.zeros(shape=input_node.shape, dtype=float)
+                    np.zeros(shape=input_shape, dtype=float)
 
     @staticmethod
     def get_scope(tensor_name):
...
@@ -25,7 +25,6 @@ from mace.python.tools.converter_tool.base_converter import DataFormat
 from mace.python.tools.converter_tool.base_converter import DeviceType
 from mace.python.tools.converter_tool.base_converter import EltwiseType
 from mace.python.tools.converter_tool.base_converter import FrameworkType
-from mace.python.tools.converter_tool.base_converter import FilterFormat
 from mace.python.tools.converter_tool.base_converter import MaceKeyword
 from mace.python.tools.converter_tool.base_converter import MaceOp
 from mace.python.tools.converter_tool.base_converter import PaddingMode
@@ -149,12 +148,12 @@ class Transformer(base_converter.ConverterInterface):
         filter_format_value = ConverterUtil.get_arg(self._model,
                                                     MaceKeyword.mace_filter_format_str).i  # noqa
         filter_format = None
-        if filter_format_value == FilterFormat.HWIO.value:
-            filter_format = FilterFormat.HWIO
-        elif filter_format_value == FilterFormat.OIHW.value:
-            filter_format = FilterFormat.OIHW
-        elif filter_format_value == FilterFormat.HWOI.value:
-            filter_format = FilterFormat.HWOI
+        if filter_format_value == DataFormat.HWIO.value:
+            filter_format = DataFormat.HWIO
+        elif filter_format_value == DataFormat.OIHW.value:
+            filter_format = DataFormat.OIHW
+        elif filter_format_value == DataFormat.HWOI.value:
+            filter_format = DataFormat.HWOI
         else:
             mace_check(False, "filter format %d not supported" %
                        filter_format_value)
@@ -614,14 +613,14 @@ class Transformer(base_converter.ConverterInterface):
                     offset = self._consts[consumer_op.input[2]]
                     idx = 0
                     filter_format = self.filter_format()
-                    if filter_format == FilterFormat.HWIO:
+                    if filter_format == DataFormat.HWIO:
                         for hwi in six.moves.range(filter.dims[0]
                                                    * filter.dims[1]
                                                    * filter.dims[2]):
                             for o in six.moves.range(filter.dims[3]):
                                 filter.float_data[idx] *= scale.float_data[o]
                                 idx += 1
-                    elif filter_format == FilterFormat.OIHW:
+                    elif filter_format == DataFormat.OIHW:
                         for o in six.moves.range(filter.dims[0]):
                             for hwi in six.moves.range(filter.dims[1]
                                                        * filter.dims[2]
@@ -673,7 +672,7 @@ class Transformer(base_converter.ConverterInterface):
                         idx = 0
                         filter_format = self.filter_format()
                         # in deconv op O and I channel is switched
-                        if filter_format == FilterFormat.HWIO:
+                        if filter_format == DataFormat.HWIO:
                             for hw in six.moves.range(filter.dims[0]
                                                       * filter.dims[1]):
                                 for o in six.moves.range(filter.dims[2]):
@@ -681,7 +680,7 @@ class Transformer(base_converter.ConverterInterface):
                                         filter.float_data[idx] *=\
                                             scale.float_data[o]
                                         idx += 1
-                        elif filter_format == FilterFormat.OIHW:
+                        elif filter_format == DataFormat.OIHW:
                             for i in six.moves.range(filter.dims[0]):
                                 for o in six.moves.range(filter.dims[1]):
                                     for hw in six.moves.range(filter.dims[2]
@@ -736,7 +735,7 @@ class Transformer(base_converter.ConverterInterface):
                         idx = 0
                         filter_format = self.filter_format()
-                        if filter_format == FilterFormat.HWIO:
+                        if filter_format == DataFormat.HWIO:
                             for hw in six.moves.range(filter.dims[0]
                                                       * filter.dims[1]):
                                 for i in six.moves.range(filter.dims[2]):
@@ -744,7 +743,7 @@ class Transformer(base_converter.ConverterInterface):
                                         filter.float_data[idx] *= scale.float_data[
                                             i * filter.dims[3] + o]
                                         idx += 1
-                        elif filter_format == FilterFormat.OIHW:
+                        elif filter_format == DataFormat.OIHW:
                             for o in six.moves.range(filter.dims[0]):
                                 for i in six.moves.range(filter.dims[1]):
                                     for hw in six.moves.range(filter.dims[2]
@@ -791,17 +790,17 @@ class Transformer(base_converter.ConverterInterface):
     @staticmethod
     def sort_filter_shape(filter_shape, filter_format):
         """Return filter shape in HWIO order"""
-        if filter_format == FilterFormat.HWIO:
+        if filter_format == DataFormat.HWIO:
             filter_height = filter_shape[0]
             filter_width = filter_shape[1]
             in_channels = filter_shape[2]
             out_channels = filter_shape[3]
-        elif filter_format == FilterFormat.OIHW:
+        elif filter_format == DataFormat.OIHW:
             filter_height = filter_shape[2]
             filter_width = filter_shape[3]
             in_channels = filter_shape[1]
             out_channels = filter_shape[0]
-        elif filter_format == FilterFormat.HWOI:
+        elif filter_format == DataFormat.HWOI:
             filter_height = filter_shape[0]
             filter_width = filter_shape[1]
             in_channels = filter_shape[3]
@@ -1006,9 +1005,9 @@ class Transformer(base_converter.ConverterInterface):
             input_shape = list(input_op.output_shape[0].dims)
             weight.dims[:] = [weight.dims[0]] + input_shape[1:]
             if len(input_shape) == 2:
-                if filter_format == FilterFormat.HWIO:
+                if filter_format == DataFormat.HWIO:
                     weight.dims[:] = [1, 1] + weight.dims[:]
-                elif filter_format == FilterFormat.OIHW:
+                elif filter_format == DataFormat.OIHW:
                     weight.dims[:] = weight.dims[:] + [1, 1]
                 else:
                     mace_check("FC does not support filter format %s",
@@ -1141,9 +1140,9 @@ class Transformer(base_converter.ConverterInterface):
         if self._option.quantize and \
                 self._option.device == DeviceType.CPU.value:
             print("Transpose filters to OHWI")
-            if filter_format == FilterFormat.HWIO:
+            if filter_format == DataFormat.HWIO:
                 transpose_order = [3, 0, 1, 2]
-            elif filter_format == FilterFormat.OIHW:
+            elif filter_format == DataFormat.OIHW:
                 transpose_order = [0, 2, 3, 1]
             else:
                 mace_check("Quantize model does not support conv "
@@ -1172,20 +1171,21 @@ class Transformer(base_converter.ConverterInterface):
                     filter.dims[:] = filter_data.shape
                     transposed_deconv_filter.add(op.input[1])
 
-            self.set_filter_format(FilterFormat.OHWI)
+            self.set_filter_format(DataFormat.OHWI)
         elif self._option.quantize and \
                 self._option.device == DeviceType.HEXAGON.value:
             print("Transpose filters to HWIO/HWIM")
-            mace_check(filter_format == FilterFormat.HWIO,
+            mace_check(filter_format == DataFormat.HWIO,
                        "HEXAGON only support HWIO/HWIM filter format.")
         else:
             print("Transpose filters to OIHW/MIHW")
            # transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
-            if filter_format == FilterFormat.HWIO:
+            if filter_format == DataFormat.HWIO:
                 for op in net.op:
                     if (op.type == MaceOp.Conv2D.name
                             or op.type == MaceOp.Deconv2D.name
                             or op.type == MaceOp.DepthwiseConv2d.name) \
+                            and op.input[1] in self._consts \
                             and op.input[1] not in transposed_filter:
                         filter = self._consts[op.input[1]]
                         filter_data = np.array(filter.float_data).reshape(
@@ -1215,7 +1215,7 @@ class Transformer(base_converter.ConverterInterface):
                     weight.dims[:] = weight_data.shape
                     transposed_filter.add(op.input[1])
 
-            self.set_filter_format(FilterFormat.OIHW)
+            self.set_filter_format(DataFormat.OIHW)
 
         # deconv's filter's output channel and input channel is reversed
         for op in net.op:
            if op.type in [MaceOp.Deconv2D.name,
@@ -1296,7 +1296,7 @@ class Transformer(base_converter.ConverterInterface):
                     len(op.input) == 2 and \
                     op.input[1] in self._consts and \
                     len(op.output_shape[0].dims) == 2 and \
-                    filter_format == FilterFormat.HWIO and \
+                    filter_format == DataFormat.HWIO and \
                     op.input[0] in self._producer:
                 input_op = self._producer[op.input[0]]
                 input_shape = input_op.output_shape[0].dims
@@ -1329,7 +1329,7 @@ class Transformer(base_converter.ConverterInterface):
             # transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
             if op.type == MaceOp.MatMul.name and \
-                    filter_format == FilterFormat.HWIO and \
+                    filter_format == DataFormat.HWIO and \
                     op.input[1] in self._consts:
                 producer = self._producer[op.input[0]]
                 weight = self._consts[op.input[1]]
...
@@ -108,6 +108,8 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
     return DataFormat::NHWC;
   } else if (data_format_str == "NCHW") {
     return DataFormat::NCHW;
+  } else if (data_format_str == "OIHW") {
+    return DataFormat::OIHW;
   } else {
     return DataFormat::DF_NONE;
   }
...
@@ -135,6 +135,7 @@ class DataFormat(object):
     NONE = "NONE"
     NHWC = "NHWC"
     NCHW = "NCHW"
+    OIHW = "OIHW"
 
 ################################
...
@@ -97,6 +97,7 @@ DataFormatStrs = [
     "NONE",
     "NHWC",
     "NCHW",
+    "OIHW",
 ]
...
@@ -178,6 +178,10 @@ def validate_tf_model(platform, device_type, model_file,
             if input_data_formats[i] == common.DataFormat.NCHW and\
                     len(input_shapes[i]) == 4:
                 input_value = input_value.transpose((0, 2, 3, 1))
+            elif input_data_formats[i] == common.DataFormat.OIHW and \
+                    len(input_shapes[i]) == 4:
+                # OIHW -> HWIO
+                input_value = input_value.transpose((2, 3, 1, 0))
             input_node = graph.get_tensor_by_name(
                 normalize_tf_tensor_name(input_names[i]))
             input_dict[input_node] = input_value
...