From 74dcd617b5cd13d2d13bca86e9201e4d3174ffd1 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 17 Apr 2019 19:43:49 +0800 Subject: [PATCH] Refactor: Support auto transformation and net optimization. 1. Auto transformation between data format, data type and memory type. 2. Add device placement optimization strategy. --- mace/core/arg_helper.cc | 38 + mace/core/arg_helper.h | 12 + mace/core/memory_optimizer.cc | 9 +- mace/core/memory_optimizer.h | 7 +- mace/core/net.cc | 327 +-------- mace/core/net.h | 8 - mace/core/net_def_adapter.cc | 651 ++++++++++++++++++ mace/core/net_def_adapter.h | 110 +++ mace/core/net_optimizer.cc | 50 ++ mace/core/net_optimizer.h | 35 + mace/core/operator.cc | 111 +-- mace/core/operator.h | 73 +- mace/core/runtime/opencl/opencl_util.cc | 30 +- mace/core/runtime/opencl/opencl_util.h | 5 +- mace/core/workspace.cc | 14 +- mace/libmace/mace.cc | 47 +- mace/ops/activation.cc | 18 + mace/ops/addn.cc | 16 + mace/ops/bias_add.cc | 16 + mace/ops/buffer_transform.cc | 6 +- mace/ops/channel_shuffle.cc | 4 +- mace/ops/concat.cc | 8 +- mace/ops/conv_2d.cc | 1 - mace/ops/crop.cc | 16 + mace/ops/deconv_2d.cc | 25 +- mace/ops/depthwise_conv2d.cc | 24 +- mace/ops/expand_dims.cc | 24 +- mace/ops/matmul.cc | 8 - mace/ops/opencl/buffer_transformer.h | 65 +- mace/ops/opencl/image/eltwise.h | 11 +- mace/ops/opencl/image/reduce.h | 5 - mace/ops/pooling.cc | 1 - mace/ops/reduce.cc | 26 + mace/ops/scalar_math.cc | 6 +- mace/ops/softmax.cc | 3 +- mace/ops/split.cc | 4 +- mace/ops/squeeze.cc | 2 +- mace/public/mace.h | 3 +- .../tools/converter_tool/base_converter.py | 39 +- .../tools/converter_tool/caffe_converter.py | 1 + .../tools/converter_tool/onnx_converter.py | 1 + .../converter_tool/tensorflow_converter.py | 1 + .../tools/converter_tool/transformer.py | 265 +++---- 43 files changed, 1456 insertions(+), 670 deletions(-) create mode 100644 mace/core/net_def_adapter.cc create mode 100644 mace/core/net_def_adapter.h create mode 100644 mace/core/net_optimizer.cc 
create mode 100644 mace/core/net_optimizer.h diff --git a/mace/core/arg_helper.cc b/mace/core/arg_helper.cc index 4f6045d8..f2a6467b 100644 --- a/mace/core/arg_helper.cc +++ b/mace/core/arg_helper.cc @@ -96,6 +96,44 @@ MACE_GET_REPEATED_ARGUMENT_FUNC(int, ints, true) MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true) #undef MACE_GET_REPEATED_ARGUMENT_FUNC +#define MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, T, fieldname) \ + template<> \ + void SetProtoArg(Def *def, \ + const std::string &arg_name, \ + const T &value) { \ + int size = def->arg_size(); \ + for (int i = 0; i < size; ++i) { \ + auto arg = def->mutable_arg(i); \ + if (arg->name() == arg_name) { \ + VLOG(3) << "Update old argument value from " \ + << arg->fieldname() << " to " \ + << value << " for " << arg_name; \ + arg->set_##fieldname(value); \ + return; \ + } \ + } \ + VLOG(3) << "Add new argument " << arg_name << "(name: " \ + << arg_name << ", value: " << value << ")"; \ + auto arg = def->add_arg(); \ + arg->set_name(arg_name); \ + arg->set_##fieldname(value); \ + } + +#define MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(Def) \ + MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, float, f) \ + MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, bool, i) \ + MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int, i) \ + MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int64_t, i) \ + MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, std::string, s) + +MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(OperatorDef) +MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(NetDef) +#undef MACE_SET_OPTIONAL_ARGUMENT_FUNC + +std::string OutputMemoryTypeTagName() { + static const char *kOutputMemTypeArgName = "output_mem_type"; + return kOutputMemTypeArgName; +} bool IsQuantizedModel(const NetDef &net_def) { return diff --git a/mace/core/arg_helper.h b/mace/core/arg_helper.h index 9d2cd243..5512fb06 100644 --- a/mace/core/arg_helper.h +++ b/mace/core/arg_helper.h @@ -55,6 +55,18 @@ class ProtoArgHelper { std::map arg_map_; }; +template +void SetProtoArg(OperatorDef *op_def, + const std::string &arg_name, + const 
T&value); + +template +void SetProtoArg(NetDef *op_def, + const std::string &arg_name, + const T&value); + +std::string OutputMemoryTypeTagName(); + bool IsQuantizedModel(const NetDef &def); } // namespace mace diff --git a/mace/core/memory_optimizer.cc b/mace/core/memory_optimizer.cc index 7f86d0eb..9b572071 100644 --- a/mace/core/memory_optimizer.cc +++ b/mace/core/memory_optimizer.cc @@ -33,7 +33,7 @@ namespace mace { bool MemoryOptimizer::IsMemoryReuseOp(const std::string &op_type) { static const std::unordered_set kReuseOp = { - "Reshape", "Identity", "Squeeze" + "Reshape", "Identity", "Squeeze", "ExpandDims" }; return kReuseOp.count(op_type) == 1; } @@ -124,8 +124,9 @@ void MemoryOptimizer::Optimize( op_def->output_type_size()); DataType dt; - bool has_data_format = ProtoArgHelper::GetOptionalArg( - *op_def, "has_data_format", 0) != 0; + DataFormat data_format = static_cast( + ProtoArgHelper::GetOptionalArg( + *op_def, "data_format", DataFormat::DF_NONE)); int output_size = op_def->output_size(); for (int i = 0; i < output_size; ++i) { if (i < op_def->output_type_size()) { @@ -209,7 +210,7 @@ void MemoryOptimizer::Optimize( mem_ref_count_[best_mem_id] = 1; } tensor_mem_map_.emplace(op_def->output(i), TensorMemInfo(best_mem_id, - dt, has_data_format)); + dt, data_format)); } } diff --git a/mace/core/memory_optimizer.h b/mace/core/memory_optimizer.h index 986c5450..b4e635f5 100644 --- a/mace/core/memory_optimizer.h +++ b/mace/core/memory_optimizer.h @@ -22,6 +22,7 @@ #include #include "mace/proto/mace.pb.h" +#include "mace/port/port.h" #include "mace/core/types.h" namespace mace { @@ -81,10 +82,10 @@ class MemoryOptimizer { struct TensorMemInfo { int mem_id; DataType data_type; - bool has_data_format; + DataFormat data_format; - TensorMemInfo(int mem_id, DataType data_type, bool has_data_format) : - mem_id(mem_id), data_type(data_type), has_data_format(has_data_format) + TensorMemInfo(int mem_id, DataType data_type, DataFormat data_format) : + mem_id(mem_id), 
data_type(data_type), data_format(data_format) {} }; diff --git a/mace/core/net.cc b/mace/core/net.cc index a10d96bb..c6e676d2 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -31,99 +31,8 @@ #include "mace/utils/memory.h" #include "mace/utils/timer.h" -#ifdef MACE_ENABLE_OPENCL -#include "mace/core/runtime/opencl/opencl_util.h" -#endif // MACE_ENABLE_OPENCL - namespace mace { -namespace { -struct InternalOutputInfo { - InternalOutputInfo(const MemoryType mem_type, - const DataType dtype, - const DataFormat data_format, - const std::vector &shape, - int op_idx) - : mem_type(mem_type), dtype(dtype), data_format(data_format), - shape(shape), op_idx(op_idx) {} - - MemoryType mem_type; // transformed memory type - DataType dtype; - DataFormat data_format; - std::vector shape; // tensor shape - int op_idx; // operation which generate the tensor -}; - -#ifdef MACE_ENABLE_OPENCL -std::string TransformedName(const std::string &input_name, - const mace::MemoryType mem_type) { - std::stringstream ss; - ss << input_name << "_mem_type_" << mem_type; - return ss.str(); -} - -bool TransformRequiredOp(const std::string &op_type) { - static const std::unordered_set kNoTransformOp = { - "Shape", "InferConv2dShape" - }; - return kNoTransformOp.count(op_type) == 0; -} -#endif // MACE_ENABLE_OPENCL - -} // namespace - -std::unique_ptr SerialNet::CreateOperation( - const OpRegistryBase *op_registry, - OpConstructContext *construct_context, - std::shared_ptr op_def, - bool has_data_format, - bool is_quantize_model) { - // Create the Operation - DeviceType target_device_type = target_device_->device_type(); - DeviceType device_type = DeviceType::CPU; - construct_context->set_device(cpu_device_.get()); - construct_context->set_operator_def(op_def); - construct_context->set_output_mem_type(MemoryType::CPU_BUFFER); - // Get available devices - auto available_devices = - op_registry->AvailableDevices(op_def->type(), construct_context); - // Find the device type to run the op. 
- // If the target_device_type in available devices, use target_device_type, - // otherwise, fallback to CPU device. - for (auto device : available_devices) { - if (device == target_device_type) { - device_type = target_device_type; - construct_context->set_device(target_device_); - if (target_device_->device_type() == DeviceType::GPU) { - construct_context->set_output_mem_type(MemoryType::GPU_IMAGE); - } - break; - } - } - op_def->set_device_type(device_type); - - // transpose output shape if run on CPU (default format is NHWC) - if (!is_quantize_model && device_type == DeviceType::CPU && - op_def->output_shape_size() == op_def->output_size()) { - for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) { - if (has_data_format && op_def->output_shape(out_idx).dims_size() == 4) { - // NHWC -> NCHW - std::vector output_shape = - TransposeShape( - std::vector( - op_def->output_shape(out_idx).dims().begin(), - op_def->output_shape(out_idx).dims().end()), - {0, 3, 1, 2}); - for (int i = 0; i < 4; ++i) { - op_def->mutable_output_shape(out_idx)->set_dims(i, output_shape[i]); - } - } - } - } - - return op_registry->CreateOperation(construct_context, device_type); -} - SerialNet::SerialNet(const OpRegistryBase *op_registry, const NetDef *net_def, Workspace *ws, @@ -138,237 +47,47 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, target_device->cpu_runtime()->policy(), &target_device->cpu_runtime()->thread_pool())) { MACE_LATENCY_LOGGER(1, "Constructing SerialNet"); - // quantize model flag - bool is_quantize_model = IsQuantizedModel(*net_def); - // Tensor Shape map - std::unordered_map> tensor_shape_map; - for (auto &op : net_def->op()) { - if (op.output_size() != op.output_shape_size()) { - continue; - } - for (int i = 0; i < op.output_size(); ++i) { - tensor_shape_map[op.output(i)] = std::vector( - op.output_shape(i).dims().begin(), - op.output_shape(i).dims().end()); - } - } - for (auto &tensor : net_def->tensors()) { - tensor_shape_map[tensor.name()] 
= - std::vector(tensor.dims().begin(), tensor.dims().end()); - } - bool has_data_format = false; - if (target_device_->device_type() == DeviceType::CPU) { - for (auto &input_info : net_def->input_info()) { - std::vector input_shape = - std::vector(input_info.dims().begin(), - input_info.dims().end()); - // update tensor shape map - tensor_shape_map[input_info.name()] = input_shape; - // Only could be NONE or NHWC - DataFormat input_data_format = static_cast( - input_info.data_format()); - has_data_format = has_data_format || - (input_data_format != DataFormat::DF_NONE); - if (!is_quantize_model && input_data_format == DataFormat::NHWC && - input_info.dims_size() == 4) { - // NHWC -> NCHW - input_shape = - TransposeShape(input_shape, {0, 3, 1, 2}); - } - } - } #ifdef MACE_ENABLE_OPENCL - // output tensor : related information - std::unordered_map output_map; // used for memory optimization std::unordered_map output_mem_map; - std::unordered_set transformed_set; - // add input information - MemoryType target_mem_type; - // default data format of output tensor - DataFormat default_output_df = DataFormat::DF_NONE; - if (target_device_->device_type() == DeviceType::GPU) { - target_mem_type = MemoryType::GPU_BUFFER; - for (auto &input_info : net_def->input_info()) { - DataFormat input_data_format = static_cast( - input_info.data_format()); - has_data_format = input_data_format != DataFormat::DF_NONE; - std::vector input_shape = - std::vector(input_info.dims().begin(), - input_info.dims().end()); - // update tensor shape map - tensor_shape_map[input_info.name()] = input_shape; - output_map.emplace(input_info.name(), InternalOutputInfo( - target_mem_type, DataType::DT_FLOAT, input_data_format, - input_shape, -1)); - } - default_output_df = - has_data_format ? 
DataFormat::NHWC : DataFormat::DF_NONE; - } #endif // MACE_ENABLE_OPENCL - OpConstructContext construct_context(ws_, &tensor_shape_map); + OpConstructContext construct_context(ws_); for (int idx = 0; idx < net_def->op_size(); ++idx) { std::shared_ptr op_def(new OperatorDef(net_def->op(idx))); // Create operation - auto op = CreateOperation(op_registry, - &construct_context, - op_def, - has_data_format, - is_quantize_model); -#ifdef MACE_ENABLE_OPENCL - // Add input transform operation if necessary - if (target_device_->device_type() == DeviceType::GPU) { - // the outputs' memory type of the operation - MemoryType out_mem_type = construct_context.output_mem_type(); - int input_size = op_def->input_size(); - // if op is memory-unused op, no transformation - if (TransformRequiredOp(op_def->type())) { - for (int i = 0; i < input_size; ++i) { - if (output_map.count(op_def->input(i)) == 1) { - // if op is memory-reuse op, no transformation - if (MemoryOptimizer::IsMemoryReuseOp(op_def->type())) { - out_mem_type = output_map.at(op_def->input(i)).mem_type; - break; - } - // check whether to do transform - MemoryType wanted_in_mem_type = - construct_context.GetInputMemType(i); - DataType wanted_in_dt = construct_context.GetInputDataType(i); - if (output_map.at(op_def->input(i)).mem_type != wanted_in_mem_type - || output_map.at(op_def->input(i)).dtype != wanted_in_dt) { - auto t_input_name = TransformedName(op_def->input(i), - wanted_in_mem_type); - auto &output_info = output_map.at(op_def->input(i)); - // check whether the tensor has been transformed - if (transformed_set.count(t_input_name) == 0) { - VLOG(1) << "Add Transform operation " << op_def->name() - << " to transform tensor " - << op_def->input(i) << "', from memory type " - << output_info.mem_type << " to " - << wanted_in_mem_type - << ", from Data Type " << output_info.dtype << " to " - << wanted_in_dt << ". 
with data format " - << output_info.data_format; - std::string input_name = op_def->input(i); - op_def->set_input(i, t_input_name); - auto input_shape = output_info.shape; - if (output_info.mem_type == MemoryType::CPU_BUFFER && - output_info.data_format == DataFormat::NCHW && - input_shape.size() == 4) { - // NCHW -> NHWC - input_shape = - TransposeShape(input_shape, - {0, 2, 3, 1}); - } - auto transform_op_def = OpenCLUtil::CreateTransformOpDef( - input_name, input_shape, t_input_name, wanted_in_dt, - construct_context.GetInputOpenCLBufferType(i), - wanted_in_mem_type, has_data_format); - OpConstructContext t_construct_context(ws_); - auto transform_op = CreateOperation( - op_registry, - &t_construct_context, - transform_op_def, - has_data_format); - operators_.emplace_back(std::move(transform_op)); - transformed_set.insert(t_input_name); - output_mem_map[t_input_name] = wanted_in_mem_type; - // where to do graph reference count. - mem_optimizer->UpdateTensorRef(transform_op_def.get()); - } else { - op_def->set_input(i, t_input_name); - } - } - } else { - MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr - && ws_->GetTensor(op_def->input(i))->is_weight(), - "Tensor ", op_def->input(i), " of ", - op_def->name(), " not allocated"); - } - } - } - // update the map : output_tensor -> Operation - for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) { - DataType dt; - if (op_def->output_type_size() == op_def->output_size()) { - dt = op_def->output_type(out_idx); - } else { - dt = static_cast( - ProtoArgHelper::GetOptionalArg( - *op_def, "T", static_cast(DataType::DT_FLOAT))); - } - output_mem_map[op_def->output(out_idx)] = out_mem_type; - output_map.emplace( - op_def->output(out_idx), - InternalOutputInfo( - out_mem_type, - dt, - default_output_df, - op_def->output_shape().empty() ? 
- std::vector() : - std::vector( - op_def->output_shape(out_idx).dims().begin(), - op_def->output_shape(out_idx).dims().end()), - static_cast(operators_.size()))); - } + auto op_device_type = static_cast(op_def->device_type()); + if (op_device_type == target_device_->device_type()) { + construct_context.set_device(target_device_); + } else if (op_device_type == DeviceType::CPU) { + construct_context.set_device(cpu_device_.get()); + } else { + LOG(FATAL) << "Encounter unexpected error: " + << op_device_type << " vs " << target_device_->device_type(); } -#endif // MACE_ENABLE_OPENCL + construct_context.set_operator_def(op_def); + + auto op = op_registry->CreateOperation(&construct_context, + op_device_type); operators_.emplace_back(std::move(op)); // where to do graph reference count. mem_optimizer->UpdateTensorRef(op_def.get()); - } #ifdef MACE_ENABLE_OPENCL - // Transform the output tensor if necessary - if (target_device_->device_type() == DeviceType::GPU) { - for (auto &output_info : net_def->output_info()) { - auto &internal_output_info = output_map.at(output_info.name()); - if ((internal_output_info.mem_type != target_mem_type && - internal_output_info.mem_type != MemoryType::CPU_BUFFER) || - internal_output_info.dtype != output_info.data_type()) { - VLOG(1) << "Add Transform operation to transform output tensor '" - << output_info.name() << "', from memory type " - << internal_output_info.mem_type - << " to " << target_mem_type - << ", from Data Type " << internal_output_info.dtype - << " to " << output_info.data_type(); - std::string t_output_name = TransformedName(output_info.name(), - target_mem_type); - auto output_op_def = - operators_[internal_output_info.op_idx]->operator_def(); - int output_size = output_op_def->output_size(); - for (int i = 0; i < output_size; ++i) { - if (output_op_def->output(i) == output_info.name()) { - output_op_def->set_output(i, t_output_name); - // update the output : mem_type map - output_mem_map[t_output_name] = 
output_mem_map[output_info.name()]; - output_mem_map[output_info.name()] = target_mem_type; - } - } - bool output_has_data_format = - static_cast(output_info.data_format()); - auto transform_op_def = OpenCLUtil::CreateTransformOpDef( - t_output_name, - internal_output_info.shape, - output_info.name(), - output_info.data_type(), - OpenCLBufferType::IN_OUT_CHANNEL, - target_mem_type, - output_has_data_format); - auto transform_op = CreateOperation( - op_registry, - &construct_context, - transform_op_def, - output_has_data_format); - operators_.emplace_back(std::move(transform_op)); - // where to do graph reference count. - mem_optimizer->UpdateTensorRef(transform_op_def.get()); + if (target_device_->device_type() == DeviceType::GPU) { + // update the map : output_tensor -> Operation + MemoryType out_mem_type = + static_cast( + ProtoArgHelper::GetOptionalArg( + net_def->op(idx), OutputMemoryTypeTagName(), + static_cast(MemoryType::CPU_BUFFER))); + for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) { + output_mem_map[op_def->output(out_idx)] = out_mem_type; } } - } #endif // MACE_ENABLE_OPENCL + } // Update output tensor reference for (auto &output_info : net_def->output_info()) { mem_optimizer->UpdateTensorRef(output_info.name()); diff --git a/mace/core/net.h b/mace/core/net.h index 788eb611..18ec5134 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -54,14 +54,6 @@ class SerialNet : public NetBase { MaceStatus Run(RunMetadata *run_metadata = nullptr) override; - private: - std::unique_ptr CreateOperation( - const OpRegistryBase *op_registry, - OpConstructContext *construct_context, - std::shared_ptr op_def, - bool has_data_format, - bool is_quantize_model = false); - protected: Workspace *ws_; Device *target_device_; diff --git a/mace/core/net_def_adapter.cc b/mace/core/net_def_adapter.cc new file mode 100644 index 00000000..fe89e810 --- /dev/null +++ b/mace/core/net_def_adapter.cc @@ -0,0 +1,651 @@ +// Copyright 2019 The MACE Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/net_def_adapter.h" + +#include +#include + +#include "mace/core/operator.h" +#include "mace/utils/math.h" +#ifdef MACE_ENABLE_OPENCL +#include "mace/core/runtime/opencl/opencl_util.h" +#endif // MACE_ENABLE_OPENCL +namespace mace { + +namespace { +DataFormat GetDefaultDataFormat(DeviceType device_type, + bool is_quantized_model) { + if (device_type == CPU) { + if (is_quantized_model) { + return DataFormat::NHWC; + } else { + return DataFormat::NCHW; + } + } else if (device_type == GPU) { + return DataFormat::NHWC; + } else { + LOG(FATAL) << "MACE do not support the device " << device_type; + return DataFormat::DF_NONE; + } +} + +template +std::string TransformedName(const std::string &input_name, + const std::string &tag, + const T value) { + std::stringstream ss; + ss << input_name << "_" << tag << "_" << value; + return ss.str(); +} + +bool TransformRequiredOp(const std::string &op_type) { + static const std::unordered_set kNoTransformOp = { + "Shape", "InferConv2dShape" + }; + return kNoTransformOp.count(op_type) == 0; +} + +void BuildTransposeOpDef( + const std::string &input_name, + const std::string &output_name, + const std::vector &output_shape, + const std::vector dst_dims, + const mace::DataType dt, + DeviceType device_type, + OperatorDef *op_def) { + std::string op_name = "mace_node_" + output_name; + op_def->set_name(op_name); + 
op_def->set_type("Transpose"); + op_def->add_input(input_name); + op_def->add_output(output_name); + op_def->set_device_type(device_type); + Argument *arg = op_def->add_arg(); + arg->set_name("dims"); + for (auto dim : dst_dims) { + arg->add_ints(dim); + } + arg = op_def->add_arg(); + arg->set_name("T"); + arg->set_i(static_cast(dt)); + if (!output_shape.empty()) { + OutputShape *shape = op_def->add_output_shape(); + for (auto value : output_shape) { + shape->add_dims(value); + } + } +} + +} // namespace + +NetDefAdapter::NetDefAdapter(const mace::OpRegistryBase *op_registry, + const mace::Workspace *ws) + : op_registry_(op_registry), ws_(ws) {} + +// Adapt original net_def to a better net. +// 1. Adapt device: choose best device for every op in the net. +// 2. Adapt data type: Add data type related transform ops +// for mixing precision. +// 3. Adapt data format: confirm data format of every op +// and add transpose if necessary. +// 4. Adapt memory type: Add BufferTransform if necessary +// for transforming memory type between ops. +MaceStatus NetDefAdapter::AdaptNetDef( + const mace::NetDef *net_def, + mace::Device *target_device, + NetDef *target_net_def) { + MACE_LATENCY_LOGGER(1, "Adapting original NetDef"); + // Copy from original op_def, leave ops alone. 
+ target_net_def->mutable_arg()->CopyFrom(net_def->arg()); + target_net_def->mutable_tensors()->CopyFrom(net_def->tensors()); + target_net_def->mutable_input_info()->CopyFrom(net_def->input_info()); + target_net_def->mutable_output_info()->CopyFrom(net_def->output_info()); + + std::unique_ptr cpu_device = make_unique( + target_device->cpu_runtime()->num_threads(), + target_device->cpu_runtime()->policy(), + target_device->cpu_runtime()->use_gemmlowp()); + + // quantize model flag + bool is_quantized_model = IsQuantizedModel(*net_def); + // Const tensors(filter) -> shape + std::unordered_map> tensor_shape_map; + // Output tensors -> information + TensorInfoMap output_map; + // output tensor : related information + std::unordered_set transformed_set; + + for (auto &tensor : net_def->tensors()) { + tensor_shape_map[tensor.name()] = + std::vector(tensor.dims().begin(), tensor.dims().end()); + } + + int input_size = target_net_def->input_info_size(); + for (int i = 0; i < input_size; ++i) { + auto input_info = target_net_def->mutable_input_info(i); + MemoryType mem_type = MemoryType::CPU_BUFFER; + if (target_device->device_type() == DeviceType::CPU) { + mem_type = MemoryType::CPU_BUFFER; + } else if (target_device->device_type() == DeviceType::GPU) { + mem_type = MemoryType::GPU_BUFFER; + } else { + LOG(FATAL) << "MACE do not support the device type: " + << target_device->device_type(); + } + DataFormat input_data_format = static_cast( + input_info->data_format()); + DataFormat expected_data_format = GetDefaultDataFormat( + target_device->device_type(), is_quantized_model); + std::vector input_shape = + std::vector(input_info->dims().begin(), + input_info->dims().end()); + if (input_data_format != DataFormat::DF_NONE + && input_data_format != expected_data_format + && input_shape.size() == 4) { + if (input_data_format == DataFormat::NHWC + && expected_data_format == DataFormat::NCHW) { + std::vector dst_dims = {0, 3, 1, 2}; + input_data_format = DataFormat::NCHW; + 
input_shape = TransposeShape(input_shape, dst_dims); + } else if (input_data_format == DataFormat::NCHW + && expected_data_format == DataFormat::NHWC) { + std::vector dst_dims = {0, 2, 3, 1}; + input_data_format = DataFormat::NHWC; + input_shape = TransposeShape(input_shape, dst_dims); + } + input_info->set_data_format(input_data_format); + int input_shape_size = input_shape.size(); + for (int j = 0; j < input_shape_size; ++j) { + input_info->set_dims(j, input_shape[j]); + } + } + output_map.emplace(input_info->name(), InternalOutputInfo( + mem_type, input_info->data_type(), + input_data_format, input_shape, -1)); + } + + OpConditionContext context(ws_, &tensor_shape_map); + DataFormat op_output_data_format; + MemoryType op_output_mem_type; + for (int idx = 0; idx < net_def->op_size(); ++idx) { + OperatorDef op_def(net_def->op(idx)); + context.set_operator_def(&op_def); + // Select device + MACE_RETURN_IF_ERROR(this->AdaptDevice(&context, + target_device, + cpu_device.get(), + output_map, + target_net_def, + &op_def)); + + // Adapt data type + MACE_RETURN_IF_ERROR(this->AdaptDataType(&context, + &op_def)); + + if (op_def.device_type() == DeviceType::GPU) { + MACE_RETURN_IF_ERROR(this->AdaptDataFormat(&context, + &op_def, + is_quantized_model, + &output_map, + &transformed_set, + &op_output_data_format, + target_net_def)); + MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context, + &op_def, + &output_map, + &transformed_set, + &op_output_mem_type, + target_net_def)); + } else { + MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context, + &op_def, + &output_map, + &transformed_set, + &op_output_mem_type, + target_net_def)); + MACE_RETURN_IF_ERROR(this->AdaptDataFormat(&context, + &op_def, + is_quantized_model, + &output_map, + &transformed_set, + &op_output_data_format, + target_net_def)); + } + + int output_size = op_def.output_size(); + for (int out_idx = 0; out_idx < output_size; ++out_idx) { + DataType dt; + if (op_def.output_type_size() == op_def.output_size()) { + dt 
= op_def.output_type(out_idx); + } else { + dt = static_cast( + ProtoArgHelper::GetOptionalArg( + op_def, "T", static_cast(DataType::DT_FLOAT))); + } + output_map.emplace( + op_def.output(out_idx), + InternalOutputInfo( + op_output_mem_type, + dt, + op_output_data_format, + op_def.output_shape().empty() ? + std::vector() : + std::vector( + op_def.output_shape(out_idx).dims().begin(), + op_def.output_shape(out_idx).dims().end()), + target_net_def->op_size())); + } + // Add op to target net + target_net_def->add_op()->CopyFrom(op_def); + } + +#ifdef MACE_ENABLE_OPENCL + if (target_device->device_type() == DeviceType::GPU) { + // Add buffer transform for GPU if necessary + MemoryType target_mem_type = MemoryType::GPU_BUFFER; + for (auto &output_info : net_def->output_info()) { + auto &internal_output_info = output_map.at(output_info.name()); + if ((internal_output_info.mem_type != target_mem_type && + internal_output_info.mem_type != MemoryType::CPU_BUFFER) || + internal_output_info.dtype != output_info.data_type()) { + VLOG(1) << "Add Transform operation to transform output tensor '" + << output_info.name() << "', from memory type " + << internal_output_info.mem_type + << " to " << target_mem_type + << ", from Data Type " << internal_output_info.dtype + << " to " << output_info.data_type(); + std::string t_output_name = TransformedName(output_info.name(), + "mem_type", + target_mem_type); + auto output_op_def = target_net_def->mutable_op( + internal_output_info.op_idx); + int output_size = output_op_def->output_size(); + for (int i = 0; i < output_size; ++i) { + if (output_op_def->output(i) == output_info.name()) { + output_op_def->set_output(i, t_output_name); + } + } + auto transformed_op_def = target_net_def->add_op(); + OpenCLUtil::BuildTransformOpDef( + t_output_name, + internal_output_info.shape, + output_info.name(), + output_info.data_type(), + OpenCLBufferType::IN_OUT_CHANNEL, + target_mem_type, + internal_output_info.data_format, + transformed_op_def); + // 
set data format arg + SetProtoArg(transformed_op_def, + "data_format", + internal_output_info.data_format); + // set output memory type argument + SetProtoArg(transformed_op_def, + OutputMemoryTypeTagName(), + target_mem_type); + } + } + } +#endif // MACE_ENABLE_OPENCL + + VLOG(1) << DebugString(target_net_def); + return MaceStatus::MACE_SUCCESS; +} + +MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context, + Device *target_device, + Device *cpu_device, + const TensorInfoMap &output_map, + const NetDef *net_def, + OperatorDef *op_def) { + VLOG(1) << "Adapt device for op " << op_def->name(); + DeviceType target_device_type = target_device->device_type(); + DeviceType device_type = DeviceType::CPU; + context->set_device(cpu_device); + if (target_device_type != DeviceType::CPU) { + std::vector producer_devices; + for (auto input : op_def->input()) { + if (output_map.count(input) == 1) { + if (output_map.at(input).op_idx < 0) { + producer_devices.push_back(target_device_type); + } else { + producer_devices.push_back( + static_cast( + net_def->op(output_map.at(input).op_idx).device_type())); + } + } + } + // Get available devices + auto available_devices = + op_registry_->AvailableDevices(op_def->type(), context); + device_type = net_optimizer_.SelectBestDevice(op_def, + target_device_type, + available_devices, + producer_devices); + if (device_type == target_device_type) { + context->set_device(target_device); + } + } + op_def->set_device_type(device_type); + return MaceStatus::MACE_SUCCESS; +} + +MaceStatus NetDefAdapter::AdaptDataType(mace::OpConditionContext *context, + mace::OperatorDef *op_def) { + MACE_UNUSED(context); + // Adjust data type of op ran on CPU + DataType dtype = static_cast( + ProtoArgHelper::GetOptionalArg( + *op_def, "T", static_cast(DT_FLOAT))); + if (op_def->device_type() == DeviceType::CPU && dtype == DT_HALF) { + SetProtoArg(op_def, "T", static_cast(DataType::DT_FLOAT)); + } + return MaceStatus::MACE_SUCCESS; +} + +MaceStatus 
NetDefAdapter::AdaptDataFormat( + mace::OpConditionContext *context, + mace::OperatorDef *op_def, + bool is_quantized_model, + TensorInfoMap *output_map, + std::unordered_set *transformed_set, + DataFormat *op_output_df, + mace::NetDef *target_net_def) { + VLOG(1) << "Adapt data format for op " << op_def->name(); + MACE_UNUSED(context); + DataFormat op_data_format = + static_cast(ProtoArgHelper::GetOptionalArg( + *op_def, "data_format", 0)); + // adjust the data format of operation + if (op_data_format == DataFormat::DF_AUTO) { + op_data_format = GetDefaultDataFormat( + static_cast(op_def->device_type()), is_quantized_model); + SetProtoArg(op_def, "data_format", static_cast(op_data_format)); + if (op_data_format == DataFormat::NCHW) { + int output_shape_size = op_def->output_shape_size(); + for (int i = 0; i < output_shape_size; ++i) { + auto output_shape = op_def->mutable_output_shape(i); + if (output_shape->dims_size() == 4) { + // transpose output shape format from NHWC to NCHW + int64_t height = output_shape->dims(1); + int64_t width = output_shape->dims(2); + output_shape->set_dims(1, output_shape->dims(3)); + output_shape->set_dims(2, height); + output_shape->set_dims(3, width); + } + } + } + } + *op_output_df = op_data_format; + + // the output memory type of transpose op is based on the consumer op's device + MemoryType target_mem_type = MemoryType::CPU_BUFFER; + if (op_def->device_type() == DeviceType::GPU) { + target_mem_type = MemoryType::GPU_BUFFER; + } + // Use op's data format as inputs' data format for now. + // Could move the logic to OpRegistry if necessary. 
+ DataFormat src_df, dst_df; + int input_size = op_def->input_size(); + for (int i = 0; i < input_size; ++i) { + if (output_map->count(op_def->input(i)) == 0) { + // check this input is const tensor(filter) + MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr + && ws_->GetTensor(op_def->input(i))->is_weight(), + "Tensor ", op_def->input(i), " of ", + op_def->name(), " is not allocated by Workspace ahead"); + continue; + } + src_df = output_map->at(op_def->input(i)).data_format; + dst_df = op_data_format; + if (src_df == DataFormat::DF_NONE + || dst_df == DataFormat::DF_NONE + || output_map->at(op_def->input(i)).shape.size() != 4) { + continue; + } + if (src_df != dst_df) { + std::string transformed_name = TransformedName(op_def->input(i), + "data_format", dst_df); + if (transformed_set->count(transformed_name) == 0) { + VLOG(1) << "Add Transpose operation " << op_def->name() + << " to transpose tensor " + << op_def->input(i) << "', from data format " + << src_df << " to " << dst_df; + // Only support transpose between NHWC and NCHW for now. + std::vector dst_dims; + if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) { + dst_dims = {0, 2, 3, 1}; + } else if (src_df == DataFormat::NHWC && dst_df == DataFormat::NCHW) { + dst_dims = {0, 3, 1, 2}; + } else { + LOG(FATAL) << "Encounter unsupported data format transpose from " + << src_df << " to " << dst_df; + } + auto &input_info = output_map->at(op_def->input(i)); + auto output_shape = input_info.shape.empty() ? 
+ std::vector() : + TransposeShape(input_info.shape, + dst_dims); + OperatorDef *transpose_op_def = target_net_def->add_op(); + BuildTransposeOpDef( + op_def->input(i), + transformed_name, + output_shape, + dst_dims, + input_info.dtype, + DeviceType::CPU, + transpose_op_def); + // set data format arg + SetProtoArg(transpose_op_def, + "data_format", + dst_df); + // set output memory type argument + SetProtoArg(transpose_op_def, + OutputMemoryTypeTagName(), + target_mem_type); + + // update output information map + output_map->emplace( + transformed_name, + InternalOutputInfo( + target_mem_type, + input_info.dtype, + dst_df, + output_shape, + target_net_def->op_size() - 1)); + // record transformed tensors + transformed_set->insert(transformed_name); + } + // update original op_def's input + op_def->set_input(i, transformed_name); + } + } + return MaceStatus::MACE_SUCCESS; +} + +MaceStatus NetDefAdapter::AdaptMemoryType( + mace::OpConditionContext *context, + mace::OperatorDef *op_def, + mace::NetDefAdapter::TensorInfoMap *output_map, + std::unordered_set *transformed_set, + MemoryType *op_output_mem_types, + mace::NetDef *target_net_def) { + VLOG(1) << "Adapt memory type for op " << op_def->name(); + // Get expected output memory type + // (only support one kind of memory type for multiple outputs) + op_registry_->GetInOutMemoryTypes(op_def->type(), context); +#ifdef MACE_ENABLE_OPENCL + int input_size = op_def->input_size(); + // if op is memory-unused op, no transformation + if (TransformRequiredOp(op_def->type())) { + for (int i = 0; i < input_size; ++i) { + if (output_map->count(op_def->input(i)) == 0) { + MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr + && ws_->GetTensor(op_def->input(i))->is_weight(), + "Tensor ", op_def->input(i), " of ", + op_def->name(), " not allocated"); + continue; + } + auto &input_info = output_map->at(op_def->input(i)); + if (input_info.data_format == DataFormat::DF_NONE + || input_info.shape.size() != 4) { + continue; + } + 
// check whether to do transform + MemoryType src_mem_type = input_info.mem_type; + MemoryType dst_mem_type = context->GetInputMemType(i); + if (src_mem_type != dst_mem_type) { + auto transformed_name = TransformedName(op_def->input(i), + "mem_type", + dst_mem_type); + // check whether the tensor has been transformed + if (transformed_set->count(transformed_name) == 0) { + VLOG(1) << "Add Transform operation " << op_def->name() + << " to transform tensor " + << op_def->input(i) << "', from memory type " + << input_info.mem_type << " to " + << dst_mem_type; + OperatorDef *transformed_op_def = target_net_def->add_op(); + OpenCLUtil::BuildTransformOpDef( + op_def->input(i), + input_info.shape, + transformed_name, + context->GetInputDataType(i), + context->GetInputOpenCLBufferType(i), + dst_mem_type, + input_info.data_format, + transformed_op_def); + // set data format arg + SetProtoArg(transformed_op_def, + "data_format", + input_info.data_format); + // set output memory type argument + SetProtoArg(transformed_op_def, + OutputMemoryTypeTagName(), + dst_mem_type); + + // update output information map + output_map->emplace( + transformed_name, + InternalOutputInfo( + dst_mem_type, + context->GetInputDataType(i), + input_info.data_format, + input_info.shape, + target_net_def->op_size() - 1)); + // record transformed tensors + transformed_set->insert(transformed_name); + } + // update original op_def's input + op_def->set_input(i, transformed_name); + } + } + } +#else + MACE_UNUSED(output_map); + MACE_UNUSED(transformed_set); + MACE_UNUSED(target_net_def); +#endif // MACE_ENABLE_OPENCL + *op_output_mem_types = context->output_mem_type(); + SetProtoArg(op_def, + OutputMemoryTypeTagName(), + context->output_mem_type()); + return MaceStatus::MACE_SUCCESS; +} + +std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) { + std::stringstream sstream; + auto DeviceTypeToStrFunc = [](DeviceType device_type) -> std::string { + if (device_type == DeviceType::CPU) { + 
return "CPU"; + } else if (device_type == DeviceType::GPU) { + return "GPU"; + } else { + return "Unknown"; + } + }; + auto MemoryTypeToStrFunc = [](MemoryType type) -> std::string { + if (type == MemoryType::CPU_BUFFER) { + return "CPU_BUFFER"; + } else if (type == MemoryType::GPU_BUFFER) { + return "GPU_BUFFER"; + } else if (type == MemoryType::GPU_IMAGE) { + return "GPU_IMAGE"; + } else { + return "Unknown"; + } + }; + auto DataFormatToStrFunc = [](DataFormat type) -> std::string { + if (type == DataFormat::NHWC) { + return "NHWC"; + } else if (type == DataFormat::NCHW) { + return "NCHW"; + } else if (type == DataFormat::DF_NONE) { + return "DF_NONE"; + } else if (type == DataFormat::DF_AUTO) { + return "DF_AUTO"; + } else if (type == DataFormat::OIHW) { + return "OIHW"; + } else { + return "Unknown"; + } + }; + for (auto &op : net_def->op()) { + std::string device_type = DeviceTypeToStrFunc( + static_cast(op.device_type())); + std::string data_type = DataTypeToString(static_cast( + ProtoArgHelper::GetOptionalArg( + op, "T", static_cast(DT_FLOAT)))); + std::string mem_type = MemoryTypeToStrFunc( + static_cast( + ProtoArgHelper::GetOptionalArg( + op, OutputMemoryTypeTagName(), + static_cast(MemoryType::CPU_BUFFER)))); + std::string data_format = DataFormatToStrFunc( + static_cast( + ProtoArgHelper::GetOptionalArg( + op, "data_format", 0))); + + sstream << std::endl; + sstream << "{" << std::endl; + sstream << " name: " << op.name() << std::endl; + sstream << " type: " << op.type() << std::endl; + sstream << " device: " << device_type << std::endl; + sstream << " data type: " << data_type << std::endl; + sstream << " data format: " << data_format << std::endl; + sstream << " memory type: " << mem_type << std::endl; + sstream << " inputs: ["; + for (const auto &input : op.input()) { + sstream << input << ", "; + } + sstream << "]" << std::endl; + sstream << " outputs: ["; + for (const auto &output : op.output()) { + sstream << output << ", "; + } + sstream << "]" << std::endl; +
sstream << " output shapes: ["; + for (auto output_shape : op.output_shape()) { + sstream << "("; + for (auto dim : output_shape.dims()) + sstream << dim << ","; + sstream << ") "; + } + sstream << "]" << std::endl; + sstream << "}"; + } + return sstream.str(); +} + +} // namespace mace diff --git a/mace/core/net_def_adapter.h b/mace/core/net_def_adapter.h new file mode 100644 index 00000000..7f3a6754 --- /dev/null +++ b/mace/core/net_def_adapter.h @@ -0,0 +1,110 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_CORE_NET_DEF_ADAPTER_H_ +#define MACE_CORE_NET_DEF_ADAPTER_H_ +#include +#include +#include +#include +#include + +#include "mace/core/types.h" +#include "mace/proto/mace.pb.h" +#include "mace/port/port.h" +#include "mace/core/operator.h" +#include "mace/core/net_optimizer.h" + +namespace mace { + +class OpRegistryBase; +class Workspace; +class Device; + +/** + * Conventions: + * 1. DataFormat::DF_AUTO stands for formatted (NHWC or NCHW) + * 2. 
If an op has DataFormat::DF_AUTO, the arguments of this op + * are formatted to NHWC + */ +class NetDefAdapter { + public: + NetDefAdapter(const OpRegistryBase *op_registry, + const Workspace *ws); + MaceStatus AdaptNetDef( + const NetDef *net_def, + Device *target_device, + NetDef *target_net_def); + + public: + NetDefAdapter(const NetDefAdapter&) = delete; + NetDefAdapter(const NetDefAdapter&&) = delete; + NetDefAdapter &operator=(const NetDefAdapter &) = delete; + NetDefAdapter &operator=(const NetDefAdapter &&) = delete; + + private: + struct InternalOutputInfo { + InternalOutputInfo(const MemoryType mem_type, + const DataType dtype, + const DataFormat data_format, + const std::vector &shape, + int op_idx) + : mem_type(mem_type), dtype(dtype), data_format(data_format), + shape(shape), op_idx(op_idx) {} + + MemoryType mem_type; + DataType dtype; + DataFormat data_format; + std::vector shape; // tensor shape + int op_idx; // index of the operation which generates the tensor + }; + + typedef std::unordered_map TensorInfoMap; + + private: + MaceStatus AdaptDevice(OpConditionContext *context, + Device *target_device, + Device *cpu_device, + const TensorInfoMap &output_map, + const NetDef *net_def, + OperatorDef *op); + MaceStatus AdaptDataType(OpConditionContext *context, + OperatorDef *op); + MaceStatus AdaptDataFormat( + OpConditionContext *context, + OperatorDef *op, + bool is_quantized_model, + TensorInfoMap *output_map, + std::unordered_set *transformed_set, + DataFormat *op_output_df, + NetDef *target_net_def); + + MaceStatus AdaptMemoryType( + mace::OpConditionContext *context, + mace::OperatorDef *op_def, + TensorInfoMap *output_map, + std::unordered_set *transformed_set, + MemoryType *op_output_mem_types, + mace::NetDef *target_net_def); + + std::string DebugString(const NetDef *net_def); + + private: + const OpRegistryBase *op_registry_; + const Workspace *ws_; + NetOptimizer net_optimizer_; +}; + +} // namespace mace +#endif // MACE_CORE_NET_DEF_ADAPTER_H_ diff --git
a/mace/core/net_optimizer.cc b/mace/core/net_optimizer.cc new file mode 100644 index 00000000..565a42c1 --- /dev/null +++ b/mace/core/net_optimizer.cc @@ -0,0 +1,58 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/net_optimizer.h" + +#include + +namespace mace { + +// Select the device an op runs on: CPU when the op is not available on the +// target device, the target device for compute-intensive ops, otherwise the +// device suggested by the ops producing its inputs. +DeviceType NetOptimizer::SelectBestDevice( + const mace::OperatorDef *op_def, + DeviceType target_device_type, + const std::set &available_devices, + const std::vector &inputs_op_devices) { + static const std::set kComputeIntensiveOps = { + "Conv2D", "DepthwiseConv2d", "Deconv2D", "DepthwiseDeconv2d", + "FullyConnected" + }; + // CPU is the device to fall back + DeviceType best_device = DeviceType::CPU; + if (available_devices.count(target_device_type) == 1) { + best_device = target_device_type; + } + if (best_device == DeviceType::CPU) { + return best_device; + } + // Put compute-intensive ops in target device + if (kComputeIntensiveOps.count(op_def->type()) == 1) { + return best_device; + } + // Greedy strategy: stay on the target device when no producer device is + // known or at least one input is produced there (independent of input + // order); otherwise follow the inputs and fall back to CPU. + if (inputs_op_devices.empty()) { + return best_device; + } + for (auto device_type : inputs_op_devices) { + if (device_type == best_device) { + return best_device; + } + } + return DeviceType::CPU; +} +} // namespace mace diff --git a/mace/core/net_optimizer.h b/mace/core/net_optimizer.h new file mode 100644 index 00000000..8ec8dc23 --- /dev/null +++ b/mace/core/net_optimizer.h @@ -0,0 +1,35 @@ +// Copyright
2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_CORE_NET_OPTIMIZER_H_ +#define MACE_CORE_NET_OPTIMIZER_H_ + +#include +#include + +#include "mace/port/port.h" +#include "mace/proto/mace.pb.h" + +namespace mace { + +class NetOptimizer { + public: + DeviceType SelectBestDevice(const OperatorDef *op_def, + DeviceType target_device, + const std::set &available_devices, + const std::vector &inputs_op_devices); +}; + +} // namespace mace +#endif // MACE_CORE_NET_OPTIMIZER_H_ diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 8fae1bd8..275189a7 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -20,34 +20,21 @@ #include "mace/core/operator.h" namespace mace { - -OpConstructContext::OpConstructContext(Workspace *ws) - : operator_def_(nullptr), - ws_(ws), - device_(nullptr), - tensor_shape_info_(nullptr) {} - -OpConstructContext::OpConstructContext( - mace::Workspace *ws, - mace::OpConstructContext::TensorShapeMap *info) +OpConditionContext::OpConditionContext( + const mace::Workspace *ws, + mace::OpConditionContext::TensorShapeMap *info) : operator_def_(nullptr), ws_(ws), device_(nullptr), tensor_shape_info_(info) {} -void OpConstructContext::set_operator_def( - std::shared_ptr operator_def) { +void OpConditionContext::set_operator_def( + const mace::OperatorDef *operator_def) { operator_def_ = operator_def; input_data_types_.clear(); } -void 
OpConstructContext::set_output_mem_type(mace::MemoryType type) { - MACE_CHECK(operator_def_ != nullptr); - output_mem_type_ = type; - input_mem_types_.clear(); -} - -void OpConstructContext::SetInputInfo(size_t idx, +void OpConditionContext::SetInputInfo(size_t idx, mace::MemoryType mem_type, mace::DataType dt) { if (input_mem_types_.empty()) { @@ -66,7 +53,13 @@ void OpConstructContext::SetInputInfo(size_t idx, input_data_types_[idx] = dt; } -MemoryType OpConstructContext::GetInputMemType(size_t idx) const { +void OpConditionContext::set_output_mem_type(mace::MemoryType type) { + MACE_CHECK(operator_def_ != nullptr); + output_mem_type_ = type; + input_mem_types_.clear(); +} + +MemoryType OpConditionContext::GetInputMemType(size_t idx) const { if (input_mem_types_.empty()) { return output_mem_type_; } @@ -75,7 +68,7 @@ MemoryType OpConstructContext::GetInputMemType(size_t idx) const { return input_mem_types_[idx]; } -DataType OpConstructContext::GetInputDataType(size_t idx) const { +DataType OpConditionContext::GetInputDataType(size_t idx) const { if (input_data_types_.empty()) { // the default inputs' data types are same as operation's data type. return static_cast( @@ -87,17 +80,17 @@ DataType OpConstructContext::GetInputDataType(size_t idx) const { } #ifdef MACE_ENABLE_OPENCL -void OpConstructContext::SetInputOpenCLBufferType( +void OpConditionContext::SetInputOpenCLBufferType( size_t idx, OpenCLBufferType buffer_type) { if (input_opencl_buffer_types_.empty()) { // the default inputs' memory types are same as output memory type. 
input_opencl_buffer_types_.resize(operator_def_->input_size(), - OpenCLBufferType::IN_OUT_CHANNEL); + OpenCLBufferType::IN_OUT_CHANNEL); } MACE_CHECK(idx < input_opencl_buffer_types_.size()); input_opencl_buffer_types_[idx] = buffer_type; } -OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType( +OpenCLBufferType OpConditionContext::GetInputOpenCLBufferType( size_t idx) const { if (input_opencl_buffer_types_.empty()) { return OpenCLBufferType::IN_OUT_CHANNEL; @@ -107,6 +100,16 @@ OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType( } #endif // MACE_ENABLE_OPENCL +OpConstructContext::OpConstructContext(Workspace *ws) + : operator_def_(nullptr), + ws_(ws), + device_(nullptr) {} + +void OpConstructContext::set_operator_def( + std::shared_ptr operator_def) { + operator_def_ = operator_def; +} + OpInitContext::OpInitContext(Workspace *ws, Device *device) : ws_(ws), device_(device) {} @@ -202,16 +205,26 @@ const std::string OpKeyBuilder::Build() { } // namespace OpRegistrationInfo::OpRegistrationInfo() { - device_placer = [this](OpConstructContext *context) -> std::set { - auto op = context->operator_def(); - // The GPU ops only support 4D In/Out tensor by default - if (this->devices.count(DeviceType::CPU) == 1 && - op->output_shape_size() == op->output_size() && - op->output_shape(0).dims_size() != 4) { - return { DeviceType::CPU }; - } + // default device type placer + device_placer = [this](OpConditionContext *context) -> std::set { + MACE_UNUSED(context); return this->devices; }; + + // default input and output memory type setter + memory_type_setter = [](OpConditionContext *context) -> void { + if (context->device()->device_type() == DeviceType::GPU) { +#ifdef MACE_ENABLE_OPENCL + if (context->device()->gpu_runtime()->UseImageMemory()) { + context->set_output_mem_type(MemoryType::GPU_IMAGE); + } else { + context->set_output_mem_type(MemoryType::GPU_BUFFER); + } +#endif // MACE_ENABLE_OPENCL + } else { + 
context->set_output_mem_type(MemoryType::CPU_BUFFER); + } + }; } void OpRegistrationInfo::AddDevice(mace::DeviceType device) { @@ -255,13 +268,21 @@ MaceStatus OpRegistryBase::Register( } const std::set OpRegistryBase::AvailableDevices( - const std::string &op_type, OpConstructContext *context) const { + const std::string &op_type, OpConditionContext *context) const { MACE_CHECK(registry_.count(op_type) != 0, op_type, " operation is not registered."); return registry_.at(op_type)->device_placer(context); } +void OpRegistryBase::GetInOutMemoryTypes( + const std::string &op_type, + mace::OpConditionContext *context) const { + MACE_CHECK(registry_.count(op_type) != 0, + op_type, " operation is not registered."); + return registry_.at(op_type)->memory_type_setter(context); +} + std::unique_ptr OpRegistryBase::CreateOperation( OpConstructContext *context, DeviceType device_type) const { @@ -269,15 +290,6 @@ std::unique_ptr OpRegistryBase::CreateOperation( DataType dtype = static_cast( ProtoArgHelper::GetOptionalArg( *operator_def, "T", static_cast(DT_FLOAT))); - if (device_type == DeviceType::CPU && dtype == DT_HALF) { - int arg_size = operator_def->arg_size(); - for (int i = 0; i < arg_size; ++i) { - if (operator_def->arg(i).name() == "T") { - operator_def->mutable_arg(i)->set_i(DT_FLOAT); - } - } - dtype = DT_FLOAT; - } VLOG(1) << "Creating operator " << operator_def->name() << "(" << operator_def->type() << "<" << dtype << ">" << ") on " << device_type; @@ -308,9 +320,20 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( return *this; } +OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter( + mace::OpRegistrationInfo::MemoryTypeSetter setter) { + memory_type_setter_ = setter; + return *this; +} + void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const { - if (info != nullptr && placer_) { - info->device_placer = placer_; + if (info != nullptr) { + if (placer_) { + info->device_placer = placer_; + } + if (memory_type_setter_) { + 
info->memory_type_setter = memory_type_setter_; + } } } diff --git a/mace/core/operator.h b/mace/core/operator.h index e59af9ab..35effdc5 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -32,22 +32,20 @@ namespace mace { -// memory_optimizer, device -class OpConstructContext { - typedef std::unordered_map> TensorShapeMap; - +// OpConditionContext has all information used for choosing proper Op +class OpConditionContext { public: - explicit OpConstructContext(Workspace *ws); - OpConstructContext(Workspace *ws, TensorShapeMap *info); - ~OpConstructContext() = default; + typedef std::unordered_map> TensorShapeMap; + OpConditionContext(const Workspace *ws, TensorShapeMap *info); + ~OpConditionContext() = default; - void set_operator_def(std::shared_ptr operator_def); + void set_operator_def(const OperatorDef* operator_def); - inline std::shared_ptr operator_def() const { + inline const OperatorDef *operator_def() const { return operator_def_; } - inline Workspace *workspace() const { + inline const Workspace *workspace() const { return ws_; } @@ -81,8 +79,8 @@ class OpConstructContext { #endif // MACE_ENABLE_OPENCL private: - std::shared_ptr operator_def_; - Workspace *ws_; + const OperatorDef *operator_def_; + const Workspace *ws_; Device *device_; TensorShapeMap *tensor_shape_info_; // used for memory transform @@ -94,6 +92,38 @@ class OpConstructContext { #endif // MACE_ENABLE_OPENCL }; +// memory_optimizer, device +class OpConstructContext { + typedef std::unordered_map> TensorShapeMap; + + public: + explicit OpConstructContext(Workspace *ws); + ~OpConstructContext() = default; + + void set_operator_def(std::shared_ptr operator_def); + + inline std::shared_ptr operator_def() const { + return operator_def_; + } + + inline Workspace *workspace() const { + return ws_; + } + + inline void set_device(Device* device) { + device_ = device; + } + + inline Device *device() const { + return device_; + } + + private: + std::shared_ptr operator_def_; + Workspace 
*ws_; + Device *device_; +}; + // memory_optimizer, device class OpInitContext { public: @@ -207,8 +237,11 @@ struct OpRegistrationInfo { public: typedef std::function(OpConstructContext *)> OpCreator; - typedef std::function(OpConstructContext *)> + typedef std::function(OpConditionContext *)> DevicePlacer; + typedef std::function MemoryTypeSetter; + typedef std::function(OpConditionContext *)> + DataFormatSelector; OpRegistrationInfo(); @@ -219,6 +252,8 @@ struct OpRegistrationInfo { std::set devices; std::unordered_map creators; DevicePlacer device_placer; + MemoryTypeSetter memory_type_setter; + DataFormatSelector data_format_selector; }; class OpConditionBuilder { @@ -230,11 +265,18 @@ class OpConditionBuilder { OpConditionBuilder &SetDevicePlacerFunc( OpRegistrationInfo::DevicePlacer placer); + // If you set input memory type for specified Op, + // you must call OpConditionContext::set_output_mem_type + OpConditionBuilder &SetInputMemoryTypeSetter( + OpRegistrationInfo::MemoryTypeSetter setter); + void Finalize(OpRegistrationInfo *info) const; private: std::string type_; OpRegistrationInfo::DevicePlacer placer_; + OpRegistrationInfo::MemoryTypeSetter memory_type_setter_; + OpRegistrationInfo::DataFormatSelector data_format_selector_; }; @@ -250,7 +292,10 @@ class OpRegistryBase { MaceStatus Register(const OpConditionBuilder &builder); const std::set AvailableDevices( - const std::string &op_type, OpConstructContext *context) const; + const std::string &op_type, OpConditionContext *context) const; + + void GetInOutMemoryTypes( + const std::string &op_type, OpConditionContext *context) const; std::unique_ptr CreateOperation( OpConstructContext *context, diff --git a/mace/core/runtime/opencl/opencl_util.cc b/mace/core/runtime/opencl/opencl_util.cc index ca114146..9f9001f3 100644 --- a/mace/core/runtime/opencl/opencl_util.cc +++ b/mace/core/runtime/opencl/opencl_util.cc @@ -147,38 +147,38 @@ void OpenCLUtil::CalImage2DShape(const std::vector &shape, /* NHWC */ } 
} -std::shared_ptr OpenCLUtil::CreateTransformOpDef( +void OpenCLUtil::BuildTransformOpDef( const std::string &input_name, const std::vector &input_shape, const std::string &output_name, const mace::DataType dt, const OpenCLBufferType buffer_type, const mace::MemoryType mem_type, - bool has_data_format) { - std::unique_ptr op(new OperatorDef); + DataFormat data_format, + OperatorDef *op_def) { std::string op_name = "mace_node_" + output_name; - op->set_name(op_name); - op->set_type("BufferTransform"); - op->add_input(input_name); - op->add_output(output_name); - Argument *arg = op->add_arg(); + op_def->set_name(op_name); + op_def->set_type("BufferTransform"); + op_def->add_input(input_name); + op_def->add_output(output_name); + op_def->set_device_type(DeviceType::GPU); + Argument *arg = op_def->add_arg(); arg->set_name("buffer_type"); arg->set_i(static_cast(buffer_type)); - arg = op->add_arg(); + arg = op_def->add_arg(); arg->set_name("mem_type"); arg->set_i(static_cast(mem_type)); - arg = op->add_arg(); + arg = op_def->add_arg(); arg->set_name("T"); arg->set_i(static_cast(dt)); - arg = op->add_arg(); - arg->set_name("has_data_format"); - arg->set_i(has_data_format); + arg = op_def->add_arg(); + arg->set_name("data_format"); + arg->set_i(data_format); if (!input_shape.empty()) { - OutputShape *shape = op->add_output_shape(); + OutputShape *shape = op_def->add_output_shape(); for (auto value : input_shape) { shape->add_dims(value); } } - return std::move(op); } } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_util.h b/mace/core/runtime/opencl/opencl_util.h index ea0e239e..2d5e2abf 100644 --- a/mace/core/runtime/opencl/opencl_util.h +++ b/mace/core/runtime/opencl/opencl_util.h @@ -43,14 +43,15 @@ class OpenCLUtil { std::vector *image_shape, const int wino_blk_size = 2); - static std::shared_ptr CreateTransformOpDef( + static void BuildTransformOpDef( const std::string &input_name, const std::vector &input_shape, const std::string &output_name, const 
mace::DataType dt, const OpenCLBufferType buffer_type, const MemoryType mem_type, - bool has_data_format); + DataFormat data_format, + OperatorDef *op_def); }; } // namespace mace diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 7cb97fe7..aa482bee 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -263,13 +263,13 @@ MaceStatus Workspace::PreallocateOutputTensor( } } VLOG(1) << "Preallocate buffer to tensors"; - bool is_quantize_model = IsQuantizedModel(net_def); for (auto &tensor_mem : mem_optimizer->tensor_mem_map()) { std::unique_ptr tensor (new Tensor(preallocated_allocator_.GetBuffer(tensor_mem.second.mem_id), tensor_mem.second.data_type, false, tensor_mem.first)); - if (tensor_mem.second.has_data_format) { + tensor->set_data_format(tensor_mem.second.data_format); + if (tensor_mem.second.data_format != DataFormat::DF_NONE) { if (mem_blocks[tensor_mem.second.mem_id].mem_type() == MemoryType::GPU_IMAGE) { VLOG(1) << "Tensor: " << tensor_mem.first @@ -279,22 +279,12 @@ MaceStatus Workspace::PreallocateOutputTensor( << tensor->UnderlyingBuffer()->shape()[0] << ", " << tensor->UnderlyingBuffer()->shape()[1]; - tensor->set_data_format(DataFormat::NHWC); } else { VLOG(1) << "Tensor: " << tensor_mem.first << " Mem: " << tensor_mem.second.mem_id << " Data type: " << tensor->dtype() << ", Buffer size: " << tensor->UnderlyingBuffer()->size(); - if (mem_blocks[tensor_mem.second.mem_id].mem_type() - == MemoryType::GPU_BUFFER || - is_quantize_model) { - tensor->set_data_format(DataFormat::NHWC); - } else { - tensor->set_data_format(DataFormat::NCHW); - } } - } else { - tensor->set_data_format(DataFormat::DF_NONE); } tensor_map_[tensor_mem.first] = std::move(tensor); } diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index c5e16b76..f00ce2e6 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -27,6 +27,7 @@ #include "mace/public/mace.h" #include "mace/port/env.h" #include "mace/port/file_system.h" +#include
"mace/core/net_def_adapter.h" #ifdef MACE_ENABLE_OPENCL #include "mace/core/runtime/opencl/gpu_device.h" @@ -512,26 +513,33 @@ MaceStatus MaceEngine::Impl::Init( } } else { #endif - MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def, - device_.get(), - model_data)); - - MemoryOptimizer mem_optimizer; - // Init model - net_ = std::unique_ptr(new SerialNet(op_registry_.get(), - net_def, - ws_.get(), - device_.get(), - &mem_optimizer)); - - // Preallocate all output tensors of ops - MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(*net_def, - &mem_optimizer, - device_.get())); - if (device_type_ == DeviceType::GPU) { - ws_->RemoveAndReloadBuffer(*net_def, model_data, device_->allocator()); - } - MACE_RETURN_IF_ERROR(net_->Init()); + MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def, + device_.get(), + model_data)); + + NetDef adapted_net_def; + NetDefAdapter net_def_adapter(op_registry_.get(), ws_.get()); + MACE_RETURN_IF_ERROR(net_def_adapter.AdaptNetDef(net_def, device_.get(), + &adapted_net_def)); + + MemoryOptimizer mem_optimizer; + // Init model + net_ = std::unique_ptr(new SerialNet(op_registry_.get(), + &adapted_net_def, + ws_.get(), + device_.get(), + &mem_optimizer)); + + // Preallocate all output tensors of ops + MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(adapted_net_def, + &mem_optimizer, + device_.get())); + if (device_type_ == DeviceType::GPU) { + ws_->RemoveAndReloadBuffer(adapted_net_def, + model_data, + device_->allocator()); + } + MACE_RETURN_IF_ERROR(net_->Init()); #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) } #endif diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc index bcdcd8e0..1d697488 100644 --- a/mace/ops/activation.cc +++ b/mace/ops/activation.cc @@ -15,6 +15,8 @@ #include "mace/ops/activation.h" #include +#include + #include "mace/core/operator.h" #if defined(MACE_ENABLE_NEON) @@ -132,6 +134,22 @@ void RegisterActivation(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Activation", ActivationOp, DeviceType::GPU,
half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Activation") + .SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + int has_data_format = + ProtoArgHelper::GetOptionalArg( + *op, "has_data_format", 0); + if (!has_data_format || + (op->output_shape_size() != op->output_size()) || + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/addn.cc b/mace/ops/addn.cc index ea6458d4..d5175180 100644 --- a/mace/ops/addn.cc +++ b/mace/ops/addn.cc @@ -103,6 +103,22 @@ void RegisterAddN(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "AddN", AddNOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("AddN") + .SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + int has_data_format = + ProtoArgHelper::GetOptionalArg( + *op, "has_data_format", 0); + if (!has_data_format || + (op->output_shape_size() != op->output_size()) || + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/bias_add.cc b/mace/ops/bias_add.cc index 9351de79..7991a088 100644 --- a/mace/ops/bias_add.cc +++ b/mace/ops/bias_add.cc @@ -145,6 +145,22 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("BiasAdd") + .SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + int has_data_format = + ProtoArgHelper::GetOptionalArg( + *op, "has_data_format", 0); + if (!has_data_format || + (op->output_shape_size() != 
op->output_size()) || + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/buffer_transform.cc b/mace/ops/buffer_transform.cc index 229d4eb9..f8bf025d 100644 --- a/mace/ops/buffer_transform.cc +++ b/mace/ops/buffer_transform.cc @@ -39,14 +39,14 @@ class BufferTransformOp : public Operation { auto type = static_cast(Operation::GetOptionalArg( "buffer_type", static_cast(CONV2D_FILTER))); - bool has_data_format = Operation::GetOptionalArg("has_data_format", 0) - != 0; + DataFormat data_format = static_cast( + Operation::GetOptionalArg("data_format", DataFormat::DF_NONE)); MemoryType in_mem_type = context->workspace()->GetTensor( operator_def_->input(0))->memory_type(); return OpenCLBufferTransformer(in_mem_type, out_mem_type_).Transform( context, input, type, out_mem_type_, wino_blk_size_, - has_data_format, output); + data_format, output); } private: diff --git a/mace/ops/channel_shuffle.cc b/mace/ops/channel_shuffle.cc index 966b5d57..09811828 100644 --- a/mace/ops/channel_shuffle.cc +++ b/mace/ops/channel_shuffle.cc @@ -116,10 +116,10 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) { op_registry, OpConditionBuilder("ChannelShuffle") .SetDevicePlacerFunc( - [](OpConstructContext *context) -> std::set { + [](OpConditionContext *context) -> std::set { auto op = context->operator_def(); if (op->output_shape_size() != op->output_size()) { - return { DeviceType::CPU, DeviceType::GPU }; + return { DeviceType::CPU }; } int groups = ProtoArgHelper::GetOptionalArg( *op, "group", 1); diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc index 9fa45feb..d2bb5713 100644 --- a/mace/ops/concat.cc +++ b/mace/ops/concat.cc @@ -241,13 +241,11 @@ void RegisterConcat(OpRegistryBase *op_registry) { op_registry, OpConditionBuilder("Concat") .SetDevicePlacerFunc( - [](OpConstructContext *context) -> std::set { + [](OpConditionContext *context) -> std::set { 
auto op = context->operator_def(); auto tensor_shape_info = context->tensor_shape_info(); - if (op->output_shape_size() != op->output_size()) { - return { DeviceType::CPU, DeviceType::GPU }; - } - if (op->output_shape(0).dims_size() != 4) { + if (op->output_shape_size() != op->output_size() || + op->output_shape(0).dims_size() != 4) { return { DeviceType::CPU }; } else { int has_data_format = diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index 5fefeddc..80e8fe78 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -466,7 +466,6 @@ class Conv2dOp : public ConvPool2dOpBase { mem_type = MemoryType::GPU_BUFFER; kernel_ = make_unique>(); } - context->set_output_mem_type(mem_type); // Transform filter tensor to target format if ((wino_block_size_ == 2 || wino_block_size_ == 4) && (kernel_->CheckUseWinograd( diff --git a/mace/ops/crop.cc b/mace/ops/crop.cc index 7265208e..9cb836ee 100644 --- a/mace/ops/crop.cc +++ b/mace/ops/crop.cc @@ -145,6 +145,22 @@ void RegisterCrop(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Crop", CropOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Crop") + .SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + int has_data_format = + ProtoArgHelper::GetOptionalArg( + *op, "has_data_format", 0); + if (!has_data_format || + (op->output_shape_size() != op->output_size()) || + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc index 5692425a..3ac54186 100644 --- a/mace/ops/deconv_2d.cc +++ b/mace/ops/deconv_2d.cc @@ -197,7 +197,6 @@ class Deconv2dOp : public Deconv2dOpBase { OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); } - context->SetInputInfo(2, MemoryType::CPU_BUFFER, DataType::DT_INT32); } } MaceStatus 
Run(OpContext *context) override { @@ -264,6 +263,30 @@ void RegisterDeconv2D(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp, DeviceType::GPU, half); + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Deconv2D") + .SetInputMemoryTypeSetter( + [](OpConditionContext *context) -> void { + MemoryType mem_type = MemoryType::CPU_BUFFER; + if (context->device()->device_type() == DeviceType::GPU) { + if (context->device()->gpu_runtime()->UseImageMemory()) { + mem_type = MemoryType::GPU_IMAGE; + } else { + MACE_NOT_IMPLEMENTED; + } + FrameworkType framework_type = + static_cast( + ProtoArgHelper::GetOptionalArg( + *(context->operator_def()), "framework_type", + FrameworkType::TENSORFLOW)); + if (framework_type == FrameworkType::TENSORFLOW) { + context->SetInputInfo(2, MemoryType::CPU_BUFFER, + DataType::DT_INT32); + } + } + context->set_output_mem_type(mem_type); + })); #endif // MACE_ENABLE_OPENCL } diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc index 67339ef9..7d389766 100644 --- a/mace/ops/depthwise_conv2d.cc +++ b/mace/ops/depthwise_conv2d.cc @@ -382,7 +382,6 @@ class DepthwiseConv2dOp : public DepthwiseConv2dOpBase { mem_type = MemoryType::GPU_BUFFER; kernel_ = make_unique>(); } - context->set_output_mem_type(mem_type); Tensor *filter_tensor = context->workspace()->GetTensor( operator_def_->input(1)); if (filter_tensor != nullptr && filter_tensor->is_weight()) { @@ -393,8 +392,6 @@ class DepthwiseConv2dOp : public DepthwiseConv2dOpBase { 1, OpenCLBufferType::DW_CONV2D_FILTER, mem_type) == MaceStatus::MACE_SUCCESS); - } else { - context->SetInputOpenCLBufferType(1, OpenCLBufferType::DW_CONV2D_FILTER); } if (operator_def_->input_size() > 2) { MACE_CHECK(TransformFilter( @@ -440,6 +437,27 @@ void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "DepthwiseConv2d", DepthwiseConv2dOp, DeviceType::GPU, half); + MACE_REGISTER_OP_CONDITION( + op_registry, + 
OpConditionBuilder("DepthwiseConv2d") + .SetInputMemoryTypeSetter( + [](OpConditionContext *context) -> void { + MemoryType mem_type = MemoryType::CPU_BUFFER; + if (context->device()->device_type() == DeviceType::GPU) { + if (context->device()->gpu_runtime()->UseImageMemory()) { + mem_type = MemoryType::GPU_IMAGE; + } else { + mem_type = MemoryType::GPU_BUFFER; + } + auto filter_tensor = context->workspace()->GetTensor( + context->operator_def()->input(1)); + if (filter_tensor == nullptr || !filter_tensor->is_weight()) { + context->SetInputOpenCLBufferType( + 1, OpenCLBufferType::DW_CONV2D_FILTER); + } + } + context->set_output_mem_type(mem_type); + })); #endif // MACE_ENABLE_OPENCL } diff --git a/mace/ops/expand_dims.cc b/mace/ops/expand_dims.cc index 78fed156..5474dd4b 100644 --- a/mace/ops/expand_dims.cc +++ b/mace/ops/expand_dims.cc @@ -14,7 +14,6 @@ #include "mace/core/operator.h" -#include "mace/ops/common/transpose.h" #include "mace/utils/math.h" namespace mace { @@ -44,27 +43,8 @@ class ExpandDimsOp : public Operation { std::vector output_shape(input_shape); output_shape.insert(output_shape.begin() + axis_, 1); - bool has_data_format = Operation::GetOptionalArg( - "has_data_format", 0) == 1; - if (has_data_format && output_shape.size() == 4) { - // only tensorflow support expand dim, so the default format is NHWC - // transform NHWC to NCHW - auto t_output_shape = TransposeShape(output_shape, - {0, 3, 1, 2}); - output->Resize(t_output_shape); - Tensor::MappingGuard input_guard(input); - Tensor::MappingGuard output_guard(output); - auto input_data = input->data(); - auto output_data = output->mutable_data(); - - Transpose(&context->device()->cpu_runtime()->thread_pool(), - input_data, output_shape, {0, 3, 1, 2}, output_data); - } else { - output->Resize(output_shape); - Tensor::MappingGuard input_guard(input); - auto input_data = input->data(); - output->Copy(input_data, input->size()); - } + output->ReuseTensorBuffer(*input); + 
output->Reshape(output_shape); return MaceStatus::MACE_SUCCESS; } diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index 65df7305..b662ce2e 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -518,14 +518,6 @@ void RegisterMatMul(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, DeviceType::CPU, uint8_t); #endif // MACE_ENABLE_QUANTIZE - -#ifdef MACE_ENABLE_OPENCL - MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, - DeviceType::GPU, float); - - MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, - DeviceType::GPU, half); -#endif // MACE_ENABLE_OPENCL } } // namespace ops diff --git a/mace/ops/opencl/buffer_transformer.h b/mace/ops/opencl/buffer_transformer.h index 20dc6d1a..d5dca3d7 100644 --- a/mace/ops/opencl/buffer_transformer.h +++ b/mace/ops/opencl/buffer_transformer.h @@ -23,7 +23,6 @@ #include "mace/ops/opencl/image/buffer_to_image.h" #include "mace/ops/opencl/image/image_to_buffer.h" #include "mace/ops/opencl/buffer/buffer_transform.h" -#include "mace/ops/common/transpose.h" #include "mace/utils/memory.h" namespace mace { @@ -48,7 +47,7 @@ class OpenCLBufferTransformer { const OpenCLBufferType type, const MemoryType out_mem_type, const int wino_blk_size, - bool has_data_format, + DataFormat data_format, Tensor *output) { Workspace *ws = context->workspace(); DataType dt = DataTypeToEnum::value; @@ -67,31 +66,12 @@ class OpenCLBufferTransformer { VLOG(2) << "Transform CPU Buffer " << input->name() << " to GPU Buffer " << internal_tensor->name() << " with data type " << dt; - if (has_data_format && input->shape().size() == 4) { - // 1. 
(NCHW -> NHWC) - std::vector dst_dims = {0, 2, 3, 1}; - std::vector output_shape = - TransposeShape(input->shape(), - dst_dims); - internal_tensor->Resize(output_shape); - internal_tensor->set_data_format(DataFormat::NHWC); - // TODO(liuqi): Only support float now - const float *input_ptr = input->data(); - Tensor::MappingGuard guard(internal_tensor); - float *internal_ptr = internal_tensor->mutable_data(); - MACE_RETURN_IF_ERROR(ops::Transpose( - &context->device()->cpu_runtime()->thread_pool(), - input_ptr, - input->shape(), - dst_dims, - internal_ptr)); - } else { - internal_tensor->Resize(input->shape()); - const uint8_t *input_ptr = input->data(); - Tensor::MappingGuard guard(internal_tensor); - uint8_t *internal_ptr = internal_tensor->mutable_data(); - memcpy(internal_ptr, input_ptr, input->raw_size()); - } + MACE_CHECK(data_format == DataFormat::NHWC); + internal_tensor->Resize(input->shape()); + const uint8_t *input_ptr = input->data(); + Tensor::MappingGuard guard(internal_tensor); + uint8_t *internal_ptr = internal_tensor->mutable_data(); + memcpy(internal_ptr, input_ptr, input->raw_size()); // 2. convert the internal GPU Buffer to output. 
return kernel_->Compute( context, internal_tensor, type, wino_blk_size, output); @@ -108,30 +88,13 @@ class OpenCLBufferTransformer { VLOG(2) << "Transform GPU Buffer " << internal_tensor.name() << " to CPU Buffer " << output->name() << " with data type " << dt; - if (has_data_format && internal_tensor.shape().size() == 4) { - // NHWC -> NCHW - std::vector dst_dims = {0, 3, 1, 2}; - std::vector output_shape = - TransposeShape(internal_tensor.shape(), - dst_dims); - output->set_data_format(DataFormat::NCHW); - Tensor::MappingGuard guard(&internal_tensor); - const float *internal_ptr = internal_tensor.data(); - output->Resize(output_shape); - float *output_ptr = output->mutable_data(); - return ops::Transpose(&context->device()->cpu_runtime()->thread_pool(), - internal_ptr, - internal_tensor.shape(), - dst_dims, - output_ptr); - } else { - Tensor::MappingGuard guard(&internal_tensor); - const T *internal_ptr = internal_tensor.data(); - output->Resize(internal_tensor.shape()); - T *output_ptr = output->mutable_data(); - memcpy(output_ptr, internal_ptr, internal_tensor.size() * sizeof(T)); - return MaceStatus::MACE_SUCCESS; - } + MACE_CHECK(data_format == DataFormat::NHWC); + Tensor::MappingGuard guard(&internal_tensor); + const T *internal_ptr = internal_tensor.data(); + output->Resize(internal_tensor.shape()); + T *output_ptr = output->mutable_data(); + memcpy(output_ptr, internal_ptr, internal_tensor.size() * sizeof(T)); + return MaceStatus::MACE_SUCCESS; } else { LOG(FATAL) << "Unexpected error: " << out_mem_type; return MaceStatus::MACE_SUCCESS; diff --git a/mace/ops/opencl/image/eltwise.h b/mace/ops/opencl/image/eltwise.h index bc1a7025..9c8a1a31 100644 --- a/mace/ops/opencl/image/eltwise.h +++ b/mace/ops/opencl/image/eltwise.h @@ -71,14 +71,17 @@ MaceStatus EltwiseKernel::Compute( if (input1 == nullptr) { input1_type = "INPUT_SCALAR"; } else { - MACE_CHECK(input0->dim_size() == input1->dim_size() || + MACE_CHECK((input0->dim_size() == input1->dim_size() + && 
input0->dim_size() == 4) || input0->dim_size() == 1 || input1->dim_size() == 1) - << "Inputs of Eltwise op must be same shape"; + << "Inputs of Eltwise op must be same shape or fulfill broadcast logic"; MACE_CHECK(type_ != EltwiseType::EQUAL) << "Eltwise op on GPU does not support EQUAL"; // broadcast - if (input0->size() != input1->size()) { - if (input0->size() < input1->size()) { + if (input0->size() != input1->size() || + input0->dim_size() != input1->dim_size()) { + if (input0->size() < input1->size() + || input0->dim_size() < input1->dim_size()) { std::swap(input0, input1); swapped = true; } diff --git a/mace/ops/opencl/image/reduce.h b/mace/ops/opencl/image/reduce.h index a2bdc652..fa69a116 100644 --- a/mace/ops/opencl/image/reduce.h +++ b/mace/ops/opencl/image/reduce.h @@ -59,11 +59,6 @@ MaceStatus ReduceKernel::Compute( const Tensor *input, Tensor *output) { MACE_CHECK_NOTNULL(input); - MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims."); - MACE_CHECK(input->dim_size() == 4, - "reduce gpu only support 4-dim input"); - MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2, - "reduce gpu only support 1,2-axis reduce"); index_t batch = input->dim(0); const index_t in_height = input->dim(1); const index_t in_width = input->dim(2); diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 52842c52..21d02e14 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -480,7 +480,6 @@ class PoolingOp : public PoolingOpBase { if (context->device()->gpu_runtime()->UseImageMemory()) { kernel_ = make_unique>(); } else { - context->set_output_mem_type(MemoryType::GPU_BUFFER); kernel_ = make_unique>(); } } diff --git a/mace/ops/reduce.cc b/mace/ops/reduce.cc index 29ce821b..86964ed9 100644 --- a/mace/ops/reduce.cc +++ b/mace/ops/reduce.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "mace/core/future.h" @@ -907,6 +908,31 @@ void RegisterReduce(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, 
DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Reduce") + .SetDevicePlacerFunc( + [](OpConditionContext *context) -> std::set { + auto op = context->operator_def(); + bool keep_dims = + ProtoArgHelper::GetOptionalArg( + *op, "keepdims", false); + if (!keep_dims) { + return { DeviceType::CPU }; + } + auto axis = + ProtoArgHelper::GetRepeatedArgs( + *op, "axis"); + if (axis.size() != 2 || axis[0] != 1 || axis[1] != 2) { + return { DeviceType::CPU }; + } + auto tensor_shape_info = context->tensor_shape_info(); + if (tensor_shape_info->count(op->input(0)) == 0 + || tensor_shape_info->at(op->input(0)).size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/scalar_math.cc b/mace/ops/scalar_math.cc index 5d311cbc..07794065 100644 --- a/mace/ops/scalar_math.cc +++ b/mace/ops/scalar_math.cc @@ -100,11 +100,7 @@ class ScalarMathOp : public Operation { coeff_(Operation::GetRepeatedArgs("coeff")), scalar_input_(Operation::GetOptionalArg("scalar_input", 1.0)), scalar_input_index_(Operation::GetOptionalArg( - "scalar_input_index", 1)) { - if (D == DeviceType::GPU) { - context->set_output_mem_type(MemoryType::GPU_BUFFER); - } - } + "scalar_input_index", 1)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc index 0eda5bf3..d5fcbc02 100644 --- a/mace/ops/softmax.cc +++ b/mace/ops/softmax.cc @@ -417,7 +417,6 @@ class SoftmaxOp : public Operation { if (context->device()->gpu_runtime()->UseImageMemory()) { kernel_ = make_unique>(use_log); } else { - context->set_output_mem_type(MemoryType::GPU_BUFFER); kernel_ = make_unique>(use_log); } } @@ -456,7 +455,7 @@ void RegisterSoftmax(OpRegistryBase *op_registry) { op_registry, OpConditionBuilder("Softmax") .SetDevicePlacerFunc( - [](OpConstructContext *context) -> std::set { + [](OpConditionContext
*context) -> std::set { auto op = context->operator_def(); if (op->output_shape_size() != op->output_size()) { return { DeviceType::CPU, DeviceType::GPU }; diff --git a/mace/ops/split.cc b/mace/ops/split.cc index e1523a06..6b646270 100644 --- a/mace/ops/split.cc +++ b/mace/ops/split.cc @@ -144,10 +144,10 @@ void RegisterSplit(OpRegistryBase *op_registry) { op_registry, OpConditionBuilder("Split") .SetDevicePlacerFunc( - [](OpConstructContext *context) -> std::set { + [](OpConditionContext *context) -> std::set { auto op = context->operator_def(); if (op->output_shape_size() != op->output_size()) { - return {DeviceType::CPU, DeviceType::GPU}; + return { DeviceType::CPU }; } int axis = ProtoArgHelper::GetOptionalArg( *op, "axis", 3); diff --git a/mace/ops/squeeze.cc b/mace/ops/squeeze.cc index 15c3408c..660a8e8f 100644 --- a/mace/ops/squeeze.cc +++ b/mace/ops/squeeze.cc @@ -77,7 +77,7 @@ void RegisterSqueeze(OpRegistryBase *op_registry) { op_registry, OpConditionBuilder("Squeeze") .SetDevicePlacerFunc( - [](OpConstructContext *context) -> std::set { + [](OpConditionContext *context) -> std::set { auto op = context->operator_def(); if (op->output_shape_size() != op->output_size()) { return { DeviceType::CPU, DeviceType::GPU }; diff --git a/mace/public/mace.h b/mace/public/mace.h index fd39fdba..dd559249 100644 --- a/mace/public/mace.h +++ b/mace/public/mace.h @@ -36,7 +36,8 @@ enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 }; enum DataFormat { DF_NONE = 0, NHWC = 1, NCHW = 2, - HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103 + HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103, + DF_AUTO = 1000, }; enum GPUPerfHint { diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 80da9b1d..8162f008 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -33,6 +33,7 @@ class DataFormat(Enum): OIHW = 101 HWOI = 102 OHWI = 103 + DF_AUTO = 
1000 # SAME_LOWER: if the amount of paddings to be added is odd, @@ -161,13 +162,39 @@ MaceSupportedOps = [ 'SumGroup', 'TargetRMSNorm', 'Transpose', - 'WinogradInverseTransform', - 'WinogradTransform', 'Cumsum', ] MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str) +MaceHasDataFormatOps = [MaceOp.BatchNorm, + MaceOp.BatchToSpaceND, + MaceOp.Conv2D, + MaceOp.Deconv2D, + MaceOp.DepthToSpace, + MaceOp.DepthwiseConv2d, + MaceOp.DepthwiseDeconv2d, + MaceOp.FullyConnected, + MaceOp.Pooling, + MaceOp.ResizeBicubic, + MaceOp.ResizeBilinear, + MaceOp.ResizeNearestNeighbor, + MaceOp.SpaceToBatchND, + MaceOp.SpaceToDepth] + +MaceMayHasDataFormatOps = [MaceOp.Activation, + MaceOp.AddN, + MaceOp.BiasAdd, + MaceOp.ChannelShuffle, + MaceOp.Concat, + MaceOp.Crop, + MaceOp.Eltwise, + MaceOp.Pad, + MaceOp.Reduce, + MaceOp.Softmax, + MaceOp.Split, + MaceOp.SqrDiffMean] + class MaceKeyword(object): # node related str @@ -505,12 +532,11 @@ class ConverterOption(object): TransformerRule.TRANSFORM_CHANNEL_SHUFFLE, # Model data format related transformation TransformerRule.TRANSPOSE_FILTERS, - TransformerRule.TRANSPOSE_DATA_FORMAT, + # Mace model structure related transformation + TransformerRule.ADD_IN_OUT_TENSOR_INFO, TransformerRule.TRANSPOSE_MATMUL_WEIGHT, # Add winograd argument TransformerRule.ADD_WINOGRAD_ARG, - # Mace model structure related transformation - TransformerRule.ADD_IN_OUT_TENSOR_INFO, # Data type related transformation TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE, # Transform finalization @@ -519,6 +545,7 @@ class ConverterOption(object): TransformerRule.SORT_BY_EXECUTION, # update the data format of ops TransformerRule.UPDATE_DATA_FORMAT, + TransformerRule.TRANSPOSE_DATA_FORMAT, # Need to be put after SORT_BY_EXECUTION TransformerRule.ADD_QUANTIZE_TENSOR_RANGE, ] @@ -571,6 +598,8 @@ class ConverterUtil(object): return DataFormat.NHWC elif arg.i == DataFormat.NCHW.value: return DataFormat.NCHW + elif arg.i == DataFormat.DF_AUTO.value: + return 
DataFormat.DF_AUTO else: return None diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index c5b61768..b65a10f4 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -195,6 +195,7 @@ class CaffeConverter(base_converter.ConverterInterface): self._option = option self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) + ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW) self._caffe_net = CaffeNet() self._caffe_layers = caffe_pb2.NetParameter() caffe_weights = caffe_pb2.NetParameter() diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 54d53db0..8974489c 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -387,6 +387,7 @@ class OnnxConverter(base_converter.ConverterInterface): self._mace_net_def = mace_pb2.NetDef() self._data_format = DataFormat.NCHW ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) + ConverterUtil.add_data_format_arg(self._mace_net_def, self._data_format) onnx_model = onnx.load(src_model_file) ir_version = onnx_model.ir_version diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 58180152..66fef5cb 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -270,6 +270,7 @@ class TensorflowConverter(base_converter.ConverterInterface): self._option = option self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO) + ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC) # import tensorflow graph tf_graph_def = tf.GraphDef() diff --git 
a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index faf33034..65c456c9 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -27,6 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType from mace.python.tools.converter_tool.base_converter import FrameworkType from mace.python.tools.converter_tool.base_converter import MaceKeyword from mace.python.tools.converter_tool.base_converter import MaceOp +from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps +from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa from mace.python.tools.converter_tool.base_converter import PaddingMode from mace.python.tools.converter_tool.base_converter import ReduceType from mace.python.tools.converter_tool.base_converter import TransformerRule @@ -77,10 +79,9 @@ class Transformer(base_converter.ConverterInterface): self.transpose_matmul_weight, TransformerRule.FOLD_FC_RESHAPE: self.fold_fc_reshape, - TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format, - TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg, TransformerRule.ADD_IN_OUT_TENSOR_INFO: self.add_in_out_tensor_info, + TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg, TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC: self.transform_global_conv_to_fc, TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight, @@ -96,6 +97,7 @@ class Transformer(base_converter.ConverterInterface): self.add_opencl_informations, TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format, + TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format, TransformerRule.CHECK_QUANTIZE_INFO: self.check_quantize_info, TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: @@ -194,21 +196,19 @@ class Transformer(base_converter.ConverterInterface): op.type = "Input" 
data_type_arg = op.arg.add() data_type_arg.name = MaceKeyword.mace_op_data_type_str - data_type_arg.i = mace_pb2.DT_FLOAT + data_type_arg.i = input_node.data_type op.output.extend([input_node.name]) output_shape = op.output_shape.add() output_shape.dims.extend(input_node.shape) - if input_node.name in self._consumers: - if ConverterUtil.data_format( - self._consumers[input_node.name][0]) \ - == DataFormat.NCHW: + if input_node.data_format != DataFormat.DF_NONE: + if input_node.data_format == DataFormat.NCHW: self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) - ConverterUtil.add_data_format_arg(op, - DataFormat.NCHW) - else: - ConverterUtil.add_data_format_arg(op, - DataFormat.NHWC) + ConverterUtil.add_data_format_arg(op, + DataFormat.DF_AUTO) + else: + ConverterUtil.add_data_format_arg(op, + DataFormat.DF_NONE) self._producer[op.output[0]] = op @staticmethod @@ -256,6 +256,13 @@ class Transformer(base_converter.ConverterInterface): else: return None + def get_tensor_data_format(self, tensor): + if tensor in self._producer: + producer = self._producer[tensor] + return ConverterUtil.data_format(producer) + else: + return DataFormat.DF_NONE + def consumer_count(self, tensor_name): return len(self._consumers.get(tensor_name, [])) @@ -838,8 +845,6 @@ class Transformer(base_converter.ConverterInterface): or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.FullyConnected.name) and len(op.input) == 2) - or (op.type == MaceOp.WinogradInverseTransform.name - and len(op.input) == 1) or (op.type == MaceOp.Deconv2D.name and ((ConverterUtil.get_arg( op, @@ -930,8 +935,7 @@ class Transformer(base_converter.ConverterInterface): or op.type == MaceOp.Deconv2D.name or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.FullyConnected.name - or op.type == MaceOp.BatchNorm.name - or op.type == MaceOp.WinogradInverseTransform.name) \ + or op.type == MaceOp.BatchNorm.name) \ and len(self._consumers.get(op.output[0], [])) == 1: consumer_op = 
self._consumers[op.output[0]][0] if consumer_op.type == MaceOp.Activation.name \ @@ -1017,96 +1021,6 @@ class Transformer(base_converter.ConverterInterface): filter_format.name) return False - def transpose_data_format(self): - net = self._model - - for op in net.op: - # transpose args - if op.type == MaceOp.Pad.name: - for arg in op.arg: - if arg.name == MaceKeyword.mace_paddings_str: - mace_check(len(arg.ints) == 8, - "pad dim rank should be 8.") - if ConverterUtil.data_format(op) == DataFormat.NCHW: - print("Transpose pad args: %s(%s)" - % (op.name, op.type)) - self.transpose_shape(arg.ints, - [0, 1, 4, 5, 6, 7, 2, 3]) - elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name: - for arg in op.arg: - if arg.name == MaceKeyword.mace_axis_str: - if (ConverterUtil.data_format(op) == DataFormat.NCHW - and len(op.output_shape[0].dims) == 4): - print("Transpose concat/split args: %s(%s)" - % (op.name, op.type)) - if arg.i == 1: - arg.i = 3 - elif arg.i == 2: - arg.i = 1 - elif arg.i == 3: - arg.i = 2 - - producer = self._producer[op.input[0]] - input_shape = producer.output_shape[0].dims - if producer.type == MaceOp.FullyConnected.name and \ - len(input_shape) == 2: - axis_arg = ConverterUtil.get_arg( - op, MaceKeyword.mace_axis_str) - if axis_arg.i == 1: - axis_arg.i = 3 - - elif op.type == MaceOp.Squeeze.name: - for arg in op.arg: - if arg.name == MaceKeyword.mace_axis_str: - if ConverterUtil.data_format(op) == DataFormat.NCHW: - print("Transpose squeeze args: %s(%s)" - % (op.name, op.type)) - mace_check(list(arg.ints) == [2, 3], - 'only support squeeze at at [2, 3]') - arg.ints[:] = [1, 2] - - elif op.type == MaceOp.Reduce.name: - for arg in op.arg: - if arg.name == MaceKeyword.mace_axis_str: - if ConverterUtil.data_format( - op) == DataFormat.NCHW: - print("Transpose reduce args: %s(%s)" - % (op.name, op.type)) - reduce_axises = list(arg.ints) - new_axises = [] - for i in range(len(reduce_axises)): - idx = reduce_axises[i] - if idx == 2 or idx == 3: - 
new_axises.append(idx - 1) - elif idx == 1: - new_axises.append(3) - else: - new_axises.append(idx) - new_axises.sort() - arg.ints[:] = [] - arg.ints.extend(new_axises) - elif op.type == MaceOp.Crop.name: - offset_arg = ConverterUtil.get_arg(op, - MaceKeyword.mace_offset_str) - mace_check(offset_arg and - ConverterUtil.data_format(op) == DataFormat.NCHW and - len(op.output_shape[0].dims) == 4, - "MACE only support crop with NCHW format") - print("Transpose crop args: %s(%s)" - % (op.name, op.type)) - self.transpose_shape(offset_arg.ints, [0, 2, 3, 1]) - - # transpose op output shape - data_format = ConverterUtil.data_format(op) - if data_format is not None \ - and data_format != DataFormat.NHWC: - print("Transpose output shapes: %s(%s)" % (op.name, op.type)) - for output_shape in op.output_shape: - if len(output_shape.dims) == 4: - self.transpose_shape(output_shape.dims, - [0, 2, 3, 1]) - - return False def add_winograd_arg(self): if self._wino_arg == 0: @@ -1428,17 +1342,121 @@ class Transformer(base_converter.ConverterInterface): def update_data_format(self): print("update data format") - data_format_flag = 1 - for input_node in self._option.input_nodes.values(): - if input_node.data_format.value == DataFormat.DF_NONE.value: - data_format_flag = 0 net = self._model for op in net.op: - ConverterUtil.del_arg( + df_arg = ConverterUtil.get_arg( op, MaceKeyword.mace_data_format_str) - has_data_format_arg = op.arg.add() - has_data_format_arg.name = MaceKeyword.mace_has_data_format_str - has_data_format_arg.i = data_format_flag + if not df_arg: + df_arg = op.arg.add() + df_arg.name = MaceKeyword.mace_data_format_str + if op.type in MaceHasDataFormatOps: + df_arg.i = DataFormat.DF_AUTO.value + elif op.type in MaceMayHasDataFormatOps: + input_df = DataFormat.DF_AUTO.value + for input_tensor in op.input: + if input_tensor in self._consts: + continue + mace_check(input_tensor in self._producer, + "Input tensor %s not in producer" % input_tensor) + father_op = 
self._producer[input_tensor] + temp_input_df = ConverterUtil.get_arg( + father_op, MaceKeyword.mace_data_format_str) + if temp_input_df.i != DataFormat.DF_AUTO.value: + input_df = temp_input_df.i + if input_df == DataFormat.DF_AUTO.value: + df_arg.i = input_df + # add flag to mark the ops may has data format + has_data_format_arg = op.arg.add() + has_data_format_arg.name = \ + MaceKeyword.mace_has_data_format_str + has_data_format_arg.i = 1 + return False + + def transpose_data_format(self): + print("Transpose arguments based on data format") + net = self._model + + src_data_format = ConverterUtil.data_format(net) + for op in net.op: + has_data_format = ConverterUtil.data_format(op) == \ + DataFormat.DF_AUTO + # transpose args + if op.type == MaceOp.Pad.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_paddings_str: + mace_check(len(arg.ints) == 8, + "pad dim rank should be 8.") + if src_data_format == DataFormat.NCHW and \ + has_data_format: + print("Transpose pad args: %s(%s)" + % (op.name, op.type)) + self.transpose_shape(arg.ints, + [0, 1, 4, 5, 6, 7, 2, 3]) + elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_axis_str: + if (src_data_format == DataFormat.NCHW + and has_data_format + and len(op.output_shape[0].dims) == 4): + print("Transpose concat/split args: %s(%s)" + % (op.name, op.type)) + if arg.i == 1: + arg.i = 3 + elif arg.i == 2: + arg.i = 1 + elif arg.i == 3: + arg.i = 2 + + producer = self._producer[op.input[0]] + input_shape = producer.output_shape[0].dims + if producer.type == MaceOp.FullyConnected.name and \ + len(input_shape) == 2: + axis_arg = ConverterUtil.get_arg( + op, MaceKeyword.mace_axis_str) + if axis_arg.i == 1: + axis_arg.i = 3 + + elif op.type == MaceOp.Reduce.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_axis_str: + if src_data_format == DataFormat.NCHW and \ + has_data_format: + print("Transpose reduce args: %s(%s)" + % (op.name, 
op.type)) + reduce_axises = list(arg.ints) + new_axises = [] + for i in range(len(reduce_axises)): + idx = reduce_axises[i] + if idx == 2 or idx == 3: + new_axises.append(idx - 1) + elif idx == 1: + new_axises.append(3) + else: + new_axises.append(idx) + new_axises.sort() + arg.ints[:] = [] + arg.ints.extend(new_axises) + elif op.type == MaceOp.Crop.name: + offset_arg = ConverterUtil.get_arg(op, + MaceKeyword.mace_offset_str) + mace_check(offset_arg and + src_data_format == DataFormat.NCHW + and has_data_format + and len(op.output_shape[0].dims) == 4, + "MACE only support crop with NCHW format") + print("Transpose crop args: %s(%s)" + % (op.name, op.type)) + self.transpose_shape(offset_arg.ints, [0, 2, 3, 1]) + + # transpose op output shape + if src_data_format == DataFormat.NCHW and \ + has_data_format: + print("Transpose output shapes: %s(%s)" % (op.name, op.type)) + for output_shape in op.output_shape: + if len(output_shape.dims) == 4: + self.transpose_shape(output_shape.dims, + [0, 2, 3, 1]) + return False def quantize_nodes(self): @@ -1493,7 +1511,7 @@ class Transformer(base_converter.ConverterInterface): self._model.input_info[i].zero_point = quantize_info.zero_point ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8) - ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) + ConverterUtil.add_data_format_arg(op_def, input_node.data_format) # use actual ranges for model input quantize find_range_every_time_arg = op_def.arg.add() find_range_every_time_arg.name = \ @@ -1516,6 +1534,7 @@ class Transformer(base_converter.ConverterInterface): self._model.output_info[i].zero_point = quantize_info.zero_point ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8) + ConverterUtil.add_data_format_arg(op_def, output_node.data_format) quantize_flag_arg = self._model.arg.add() quantize_flag_arg.name = MaceKeyword.mace_quantize_flag_arg_str @@ -1886,9 +1905,6 @@ class Transformer(base_converter.ConverterInterface): shape_tensor.data_type = mace_pb2.DT_INT32 
else: mace_check(False, "Only support reshape and flatten") - # NCHW -> NHWC - if len(dims) == 4: - self.transpose_shape(dims, [0, 2, 3, 1]) shape_tensor.int32_data.extend(dims) op.input.append(shape_tensor.name) @@ -2030,6 +2046,9 @@ class Transformer(base_converter.ConverterInterface): data_type_arg = quantize_op.arg.add() data_type_arg.name = MaceKeyword.mace_op_data_type_str data_type_arg.i = mace_pb2.DT_UINT8 + ConverterUtil.add_data_format_arg( + quantize_op, + self.get_tensor_data_format(input_tensor)) data_type_arg = quantize_op.arg.add() data_type_arg.name = MaceKeyword.mace_non_zero @@ -2050,8 +2069,8 @@ class Transformer(base_converter.ConverterInterface): del op.input[:] op.input.extend(quantized_inputs_names) - orginal_output_name = op.output[0] - op.output[0] = orginal_output_name + "_quant" + original_output_name = op.output[0] + op.output[0] = original_output_name + "_quant" op.output_type.extend([to_quantize_ops_output_type[op.type]]) data_type_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_op_data_type_str) # noqa @@ -2064,13 +2083,15 @@ class Transformer(base_converter.ConverterInterface): dequantize_op.name = op.name + "_dequant" dequantize_op.type = MaceOp.Dequantize.name dequantize_op.input.extend([op.output[0]]) - dequantize_op.output.extend([orginal_output_name]) + dequantize_op.output.extend([original_output_name]) dequantize_op.output_shape.extend(op.output_shape) dequantize_op.output_type.extend([mace_pb2.DT_FLOAT]) data_type_arg = dequantize_op.arg.add() data_type_arg.name = MaceKeyword.mace_op_data_type_str data_type_arg.i = to_quantize_ops_output_type[op.type] - + ConverterUtil.add_data_format_arg( + dequantize_op, + self.get_tensor_data_format(original_output_name)) quantize_flag_arg = ConverterUtil.get_arg(self._model, MaceKeyword.mace_quantize_flag_arg_str) # noqa if quantize_flag_arg is None: -- GitLab