diff --git a/mace/core/net.cc b/mace/core/net.cc index 2aeb951e92fe8a44cb814caf6ab13eaf5c6bae7c..4cdddaac8ed8e5c98d6277593f127279ffe8bf9e 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -64,92 +64,26 @@ bool TransformRequiredOp(const std::string &op_type) { } #endif // MACE_ENABLE_OPENCL - - -// TODO(lichao): Move to runtime driver class after universality done. -// fallback to gpu buffer when kernels are implemented -void FindAvailableDevicesForOp(const OpRegistryBase &op_registry, - const OperatorDef &op, - const std::unordered_map> &tensor_shape_info, - std::set - *available_devices) { - auto devices = op_registry.AvailableDevices(op.type()); - available_devices->insert(devices.begin(), devices.end()); - std::string op_type = op.type(); - // For those whose shape is not 4-rank but can run on GPU - if (op_type == "BufferTransform" - || op_type == "LSTMCell" - || op_type == "FullyConnected" - || op_type == "Softmax" - || op_type == "Squeeze") { - return; - } else { - if (op.output_shape_size() != op.output_size()) { - return; - } - if (op.output_shape(0).dims_size() != 4) { - available_devices->erase(DeviceType::GPU); - } - - if (op_type == "Split") { - if (op.output_shape(0).dims_size() != 4 - || op.output_shape(0).dims()[3] % 4 != 0) { - available_devices->erase(DeviceType::GPU); - } - } else if (op_type == "Concat") { - if (op.output_shape(0).dims_size() != 4) { - available_devices->erase(DeviceType::GPU); - } else { - if (op.input_size() != 2) { - for (const std::string &input : op.input()) { - if (tensor_shape_info.find(input) != tensor_shape_info.end()) { - auto &input_shape = tensor_shape_info.at(input); - if (input_shape[3] % 4 != 0) { - available_devices->erase(DeviceType::GPU); - break; - } - } - } - } - } - } else if (op_type == "ChannelShuffle") { - int groups = ProtoArgHelper::GetOptionalArg( - op, "group", 1); - int channels = op.output_shape(0).dims(3); - int channels_per_group = channels / groups; - if (groups % 4 != 0 || channels_per_group % 4 != 0) { - available_devices->erase(DeviceType::GPU); - } - } - } -} - } // namespace std::unique_ptr SerialNet::CreateOperation( const OpRegistryBase *op_registry, OpConstructContext *construct_context, std::shared_ptr op_def, - const std::unordered_map> tensor_shape_info, DataFormat data_format_flag, bool is_quantize_model) { // Create the Operation DeviceType target_device_type = target_device_->device_type(); - // Get available devices - std::set available_devices; - FindAvailableDevicesForOp(*op_registry, - *op_def, - tensor_shape_info, - &available_devices); - // Find the device type to run the op. - // If the target_device_type in available devices, use target_device_type, - // otherwise, fallback to CPU device. DeviceType device_type = DeviceType::CPU; construct_context->set_device(cpu_device_); construct_context->set_operator_def(op_def); construct_context->set_output_mem_type(MemoryType::CPU_BUFFER); + // Get available devices + auto available_devices = + op_registry->AvailableDevices(op_def->type(), construct_context); + // Find the device type to run the op. + // If the target_device_type in available devices, use target_device_type, + // otherwise, fallback to CPU device. for (auto device : available_devices) { if (device == target_device_type) { device_type = target_device_type; @@ -208,6 +142,23 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, MemoryType target_mem_type; // quantize model flag bool is_quantize_model = IsQuantizedModel(*net_def); + // Tensor Shape map + std::unordered_map> tensor_shape_map; + for (auto &op : net_def->op()) { + if (op.output_size() != op.output_shape_size()) { + continue; + } + for (int i = 0; i < op.output_size(); ++i) { + tensor_shape_map[op.output(i)] = + std::move(std::vector(op.output_shape(i).dims().begin(), + op.output_shape(i).dims().end())); + } + } + for (auto &tensor : net_def->tensors()) { + tensor_shape_map[tensor.name()] = + std::move(std::vector(tensor.dims().begin(), + tensor.dims().end())); + } DataFormat data_format_flag = NHWC; if (target_device_->device_type() == DeviceType::CPU) { @@ -216,11 +167,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, std::vector input_shape = std::vector(input_info.dims().begin(), input_info.dims().end()); + // update tensor shape map + tensor_shape_map[input_info.name()] = input_shape; // Only could be NONE or NHWC auto input_data_format = static_cast( input_info.data_format()); - if (!is_quantize_model && - input_data_format == NHWC && + if (!is_quantize_model && input_data_format == NHWC && input_info.dims_size() == 4) { // NHWC -> NCHW input_shape = @@ -237,39 +189,29 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, else { // GPU NOLINT[readability/braces] target_mem_type = MemoryType::GPU_BUFFER; for (auto &input_info : net_def->input_info()) { + auto input_data_format = static_cast( + input_info.data_format()); + if (input_data_format == DataFormat::DF_NONE) { + data_format_flag = DataFormat::DF_NONE; + } std::vector input_shape = std::vector(input_info.dims().begin(), input_info.dims().end()); + // update tensor shape map + tensor_shape_map[input_info.name()] = input_shape; output_map.emplace(input_info.name(), InternalOutputInfo( target_mem_type, DataType::DT_FLOAT, input_shape, -1)); } } #endif // MACE_ENABLE_OPENCL - std::unordered_map> tensor_shape_info; - for (auto &op : net_def->op()) { - if (op.output_size() != op.output_shape_size()) { - continue; - } - for (int i = 0; i < op.output_size(); ++i) { - tensor_shape_info[op.output(i)] = - std::move(std::vector(op.output_shape(i).dims().begin(), - op.output_shape(i).dims().end())); - } - } - for (auto &tensor : net_def->tensors()) { - tensor_shape_info[tensor.name()] = - std::move(std::vector(tensor.dims().begin(), - tensor.dims().end())); - } - OpConstructContext construct_context(ws_); + OpConstructContext construct_context(ws_, &tensor_shape_map); for (int idx = 0; idx < net_def->op_size(); ++idx) { std::shared_ptr op_def(new OperatorDef(net_def->op(idx))); // Create operation auto op = CreateOperation(op_registry, &construct_context, op_def, - tensor_shape_info, data_format_flag, is_quantize_model); #ifdef MACE_ENABLE_OPENCL @@ -317,12 +259,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, } auto transform_op_def = OpenCLUtil::CreateTransformOpDef( input_name, input_shape, t_input_name, - wanted_in_dt, wanted_in_mem_type); + wanted_in_dt, wanted_in_mem_type, data_format_flag); + OpConstructContext t_construct_context(ws_); auto transform_op = CreateOperation( op_registry, - &construct_context, + &t_construct_context, transform_op_def, - tensor_shape_info, data_format_flag); operators_.emplace_back(std::move(transform_op)); transformed_set.insert(t_input_name); @@ -405,12 +347,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, internal_output_info.shape, output_info.name(), output_info.data_type(), - target_mem_type); + target_mem_type, + data_format_flag); auto transform_op = CreateOperation( op_registry, &construct_context, transform_op_def, - tensor_shape_info, output_data_format); operators_.emplace_back(std::move(transform_op)); // where to do graph reference count. diff --git a/mace/core/net.h b/mace/core/net.h index 5362d9ee4b8630a894cbb5705d6503bce2ed85f2..10577a572f5a0629ae515d9b330befbaa639016e 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -59,8 +59,6 @@ class SerialNet : public NetBase { const OpRegistryBase *op_registry, OpConstructContext *construct_context, std::shared_ptr op_def, - const std::unordered_map> tensor_shape_info, DataFormat input_format, bool is_quantize_model = false); diff --git a/mace/core/operator.cc b/mace/core/operator.cc index ad88c35b2d0bc0b5a216148084783cc5941cf9d1..4fd23db40e4a7ec419cbc2abeaa4f0cf8a198a24 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -22,7 +22,18 @@ namespace mace { OpConstructContext::OpConstructContext(Workspace *ws) - : operator_def_(nullptr), ws_(ws), device_(nullptr) {} + : operator_def_(nullptr), + ws_(ws), + device_(nullptr), + tensor_shape_info_(nullptr) {} + +OpConstructContext::OpConstructContext( + mace::Workspace *ws, + mace::OpConstructContext::TensorShapeMap *info) + : operator_def_(nullptr), + ws_(ws), + device_(nullptr), + tensor_shape_info_(info) {} void OpConstructContext::set_operator_def( std::shared_ptr operator_def) { @@ -169,6 +180,19 @@ const std::string OpKeyBuilder::Build() { } } // namespace +OpRegistrationInfo::OpRegistrationInfo() { + device_placer = [this](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + // The GPU ops only support 4D In/Out tensor by default + if (this->devices.count(DeviceType::CPU) == 1 && + op->output_shape_size() == op->output_size() && + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return this->devices; + }; +} + void OpRegistrationInfo::AddDevice(mace::DeviceType device) { devices.insert(device); } @@ -179,10 +203,11 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) { creators[key] = creator; } -MaceStatus OpRegistryBase::Register(const std::string &op_type, - const mace::DeviceType device_type, - const mace::DataType dt, - mace::OpRegistrationInfo::OpCreator creator) { +MaceStatus OpRegistryBase::Register( + const std::string &op_type, + const mace::DeviceType device_type, + const mace::DataType dt, + mace::OpRegistrationInfo::OpCreator creator) { if (registry_.count(op_type) == 0) { registry_[op_type] = std::unique_ptr( new OpRegistrationInfo); @@ -197,15 +222,25 @@ MaceStatus OpRegistryBase::Register(const std::string &op_type, return MaceStatus::MACE_SUCCESS; } +MaceStatus OpRegistryBase::Register( + const OpConditionBuilder &builder) { + std::string op_type = builder.type(); + if (registry_.count(op_type) == 0) { + registry_[op_type] = std::unique_ptr( + new OpRegistrationInfo); + } + builder.Finalize(registry_[op_type].get()); + return MaceStatus::MACE_SUCCESS; +} + const std::set OpRegistryBase::AvailableDevices( - const std::string &op_type) const { + const std::string &op_type, OpConstructContext *context) const { MACE_CHECK(registry_.count(op_type) != 0, op_type, " operation is not registered."); - return registry_.at(op_type)->devices; + return registry_.at(op_type)->device_placer(context); } - std::unique_ptr OpRegistryBase::CreateOperation( OpConstructContext *context, DeviceType device_type) const { @@ -238,4 +273,24 @@ std::unique_ptr OpRegistryBase::CreateOperation( } return registry_.at(op_type)->creators.at(key)(context); } + +OpConditionBuilder::OpConditionBuilder(const std::string &type) + : type_(type) {} + +const std::string OpConditionBuilder::type() const { + return type_; +} + +OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( + OpRegistrationInfo::DevicePlacer placer) { + placer_ = placer; + return *this; +} + +void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const { + if (info != nullptr && placer_) { + info->device_placer = placer_; + } +} + } // namespace mace diff --git a/mace/core/operator.h b/mace/core/operator.h index 5a119d1ee0cde520ac1820117080c7d0a19bc52b..e1e35f35e161ba26f4d345d1124985816224e5a7 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -31,8 +31,11 @@ namespace mace { // memory_optimizer, device class OpConstructContext { + typedef std::unordered_map> TensorShapeMap; + public: explicit OpConstructContext(Workspace *ws); + OpConstructContext(Workspace *ws, TensorShapeMap *info); ~OpConstructContext() = default; void set_operator_def(std::shared_ptr operator_def); @@ -53,6 +56,10 @@ class OpConstructContext { return device_; } + inline TensorShapeMap *tensor_shape_info() const { + return tensor_shape_info_; + } + void set_output_mem_type(MemoryType type); inline MemoryType output_mem_type() const { @@ -69,6 +76,7 @@ class OpConstructContext { std::shared_ptr operator_def_; Workspace *ws_; Device *device_; + TensorShapeMap *tensor_shape_info_; // used for memory transform std::vector input_mem_types_; std::vector input_data_types_; @@ -188,8 +196,10 @@ struct OpRegistrationInfo { public: typedef std::function(OpConstructContext *)> OpCreator; + typedef std::function(OpConstructContext *)> + DevicePlacer; - OpRegistrationInfo() = default; + OpRegistrationInfo(); void AddDevice(DeviceType); @@ -197,8 +207,26 @@ struct OpRegistrationInfo { std::set devices; std::unordered_map creators; + DevicePlacer device_placer; +}; + +class OpConditionBuilder { + public: + explicit OpConditionBuilder(const std::string &type); + + const std::string type() const; + + OpConditionBuilder &SetDevicePlacerFunc( + OpRegistrationInfo::DevicePlacer placer); + + void Finalize(OpRegistrationInfo *info) const; + + private: + std::string type_; + OpRegistrationInfo::DevicePlacer placer_; }; + class OpRegistryBase { public: OpRegistryBase() = default; @@ -208,8 +236,10 @@ class OpRegistryBase { const DataType dt, OpRegistrationInfo::OpCreator creator); + MaceStatus Register(const OpConditionBuilder &builder); + const std::set AvailableDevices( - const std::string &op_type) const; + const std::string &op_type, OpConstructContext *context) const; std::unique_ptr CreateOperation( OpConstructContext *context, @@ -234,6 +264,9 @@ class OpRegistryBase { DataTypeToEnum
::value, \ OpRegistryBase::DefaultCreator>) +#define MACE_REGISTER_OP_CONDITION(op_registry, builder) \ + op_registry->Register(builder) + } // namespace mace #endif // MACE_CORE_OPERATOR_H_ diff --git a/mace/core/runtime/opencl/opencl_util.cc b/mace/core/runtime/opencl/opencl_util.cc index 02ffc8e02222492e9ec9f8d7a0688c9e3c49c5e7..ca9e63dd70d04f36bb81a0d0fc2f0d344a558b72 100644 --- a/mace/core/runtime/opencl/opencl_util.cc +++ b/mace/core/runtime/opencl/opencl_util.cc @@ -151,7 +151,8 @@ std::shared_ptr OpenCLUtil::CreateTransformOpDef( const std::vector &input_shape, const std::string &output_name, const mace::DataType dt, - const mace::MemoryType mem_type) { + const mace::MemoryType mem_type, + const DataFormat data_format) { std::unique_ptr op(new OperatorDef); std::string op_name = "mace_node_" + output_name; op->set_name(op_name); @@ -168,8 +169,8 @@ std::shared_ptr OpenCLUtil::CreateTransformOpDef( arg->set_name("T"); arg->set_i(static_cast(dt)); arg = op->add_arg(); - arg->set_name("device"); - arg->set_i(DeviceType::GPU); + arg->set_name("data_format"); + arg->set_i(data_format); if (!input_shape.empty()) { OutputShape *shape = op->add_output_shape(); for (auto value : input_shape) { diff --git a/mace/core/runtime/opencl/opencl_util.h b/mace/core/runtime/opencl/opencl_util.h index eb518317455dccebb6e05a7456765fbd0700f566..5449a8e1ee3d00eab041e0ee7bbd650627cfd909 100644 --- a/mace/core/runtime/opencl/opencl_util.h +++ b/mace/core/runtime/opencl/opencl_util.h @@ -20,6 +20,7 @@ #include #include "mace/core/types.h" +#include "mace/public/mace.h" namespace mace { enum OpenCLBufferType { @@ -47,7 +48,8 @@ class OpenCLUtil { const std::vector &input_shape, const std::string &output_name, const mace::DataType dt, - const MemoryType mem_type); + const MemoryType mem_type, + const DataFormat data_format); }; } // namespace mace diff --git a/mace/ops/buffer_to_image_benchmark.cc b/mace/ops/buffer_to_image_benchmark.cc index 4ba0f64c1ce2354f3e8b133664303dff59896a07..2097e34e7ea0ef5b24b20fcda30d66b7de07983a 100644 --- a/mace/ops/buffer_to_image_benchmark.cc +++ b/mace/ops/buffer_to_image_benchmark.cc @@ -46,6 +46,7 @@ void FilterBufferToImage(int iters, OpenCLBufferType::IN_OUT_CHANNEL, MemoryType::GPU_IMAGE, 0, + DataFormat::NHWC, b2i_output); }; diff --git a/mace/ops/buffer_to_image_test.cc b/mace/ops/buffer_to_image_test.cc index e6a65aa258fa8c76328c5be88a99e04e0bb1f074..dcd5569102552116fc0029bd669d569b0eef90d7 100644 --- a/mace/ops/buffer_to_image_test.cc +++ b/mace/ops/buffer_to_image_test.cc @@ -35,14 +35,14 @@ void TestBidirectionTransform(const OpenCLBufferType type, OpenCLBufferTransformer(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) .Transform(&context, net.ws()->GetTensor("Input"), - type, MemoryType::GPU_IMAGE, 0, b2i_output); + type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); // Inverse Transform Tensor *i2b_output = net.ws()->CreateTensor( "I2BOutput", context.device()->allocator(), DataTypeToEnum::value); OpenCLBufferTransformer(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) .Transform(&context, b2i_output, - type, MemoryType::GPU_BUFFER, 0, i2b_output); + type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); // Check ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), @@ -176,14 +176,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type, OpenCLBufferTransformer(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) .Transform(&context, net.ws()->GetTensor("Input"), - type, MemoryType::GPU_IMAGE, 0, b2i_output); + type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); // Inverse Transform Tensor *i2b_output = net.ws()->CreateTensor( "I2BOutput", context.device()->allocator(), DT_FLOAT); OpenCLBufferTransformer(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) .Transform(&context, b2i_output, - type, MemoryType::GPU_BUFFER, 0, i2b_output); + type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); // Check ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), @@ -216,14 +216,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type, // Transform OpenCLBufferTransformer(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) .Transform(&context, net.ws()->GetTensor("Input"), - type, MemoryType::GPU_IMAGE, 0, b2i_output); + type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); // Inverse Transform Tensor *i2b_output = net.ws()->CreateTensor( "I2BOutput", context.device()->allocator(), DataTypeToEnum::value); OpenCLBufferTransformer(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) .Transform(&context, b2i_output, - type, MemoryType::GPU_BUFFER, 0, i2b_output); + type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); // Check ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), diff --git a/mace/ops/buffer_transform.cc b/mace/ops/buffer_transform.cc index 088b149d894b76439b32087134d3958da4ad0af4..d2c1f3e795ff8fe07e20380dfcf5d7c3ab04b828 100644 --- a/mace/ops/buffer_transform.cc +++ b/mace/ops/buffer_transform.cc @@ -39,11 +39,14 @@ class BufferTransformOp : public Operation { auto type = static_cast(Operation::GetOptionalArg( "buffer_type", static_cast(CONV2D_FILTER))); + auto data_format = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); MemoryType in_mem_type = context->workspace()->GetTensor( operator_def_->input(0))->memory_type(); return OpenCLBufferTransformer(in_mem_type, out_mem_type_).Transform( - context, input, type, out_mem_type_, wino_blk_size_, output); + context, input, type, out_mem_type_, wino_blk_size_, + data_format, output); } private: diff --git a/mace/ops/buffer_transform_test.cc b/mace/ops/buffer_transform_test.cc index c18e81cf99f4b8d6d1fef29ba3d95aa8873292f2..53be39e4a2764d8845cc822de0730c79f37fbfd9 100644 --- a/mace/ops/buffer_transform_test.cc +++ b/mace/ops/buffer_transform_test.cc @@ -46,7 +46,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, OpenCLBufferTransformer(MemoryType::GPU_BUFFER, MemoryType::GPU_BUFFER) .Transform(&context, net.ws()->GetTensor("Input"), - type, MemoryType::GPU_BUFFER, 0, bt_output); + type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, bt_output); // Inverse Transform Tensor *output = net.ws()->CreateTensor( @@ -55,7 +55,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, OpenCLBufferTransformer(MemoryType::GPU_BUFFER, MemoryType::GPU_BUFFER) .Transform(&context, bt_output, - type, MemoryType::GPU_BUFFER, 0, output); + type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, output); if (DataTypeToEnum::value == DataTypeToEnum::value) { EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(), @@ -92,7 +92,7 @@ void TestArgumentTransform(const index_t input_size) { MemoryType::GPU_BUFFER) .Transform(&context, net.ws()->GetTensor("Input"), OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER, - 0, output); + 0, DataFormat::NHWC, output); index_t expected_size = RoundUp(input_size, 4); EXPECT_EQ(expected_size, output->buffer_shape()[0]); diff --git a/mace/ops/channel_shuffle.cc b/mace/ops/channel_shuffle.cc index 8301ccb54681bcf5fc1e521ec603ede8fc2d205f..4f8a6f9a03a718ba5cab61d45cd2e11ffab1b2c0 100644 --- a/mace/ops/channel_shuffle.cc +++ b/mace/ops/channel_shuffle.cc @@ -111,6 +111,28 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "ChannelShuffle", ChannelShuffleOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("ChannelShuffle") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + int groups = ProtoArgHelper::GetOptionalArg( + *op, "group", 1); + if (op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + index_t channels = op->output_shape(0).dims(3); + index_t channels_per_group = channels / groups; + if (groups % 4 != 0 || channels_per_group % 4 != 0) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc index 3fa5ef2c5097e9c2a38f68fac1707a46bb440777..ae4f23e04a818f104e48a186c9350aadc428ddd6 100644 --- a/mace/ops/concat.cc +++ b/mace/ops/concat.cc @@ -59,7 +59,9 @@ class ConcatOp : public ConcatOpBase { MACE_UNUSED(context); if (!checked_) { Validate(); - if (this->Input(0)->dim_size() == 4) { + auto df = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); + if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { if (axis_ == 3) axis_ = 1; else if (axis_ == 2) axis_ = 3; else if (axis_ == 1) axis_ = 2; @@ -232,7 +234,42 @@ void RegisterConcat(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Concat", ConcatOp, DeviceType::GPU, half); + #endif // MACE_ENABLE_OPENCL + + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Concat") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + auto tensor_shape_info = context->tensor_shape_info(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + if (op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } else { + int axis = ProtoArgHelper::GetOptionalArg( + *op, "axis", 3); + if (axis != 3) { + return { DeviceType::CPU }; + } + bool divisible_four = true; + for (const std::string &input : op->input()) { + if (tensor_shape_info->find(input) + != tensor_shape_info->end()) { + divisible_four = divisible_four + && (tensor_shape_info->at(input)[3] % 4 == 0); + } + } + // Only support not divisible 4 case with 2 inputs. + if (op->input_size() > 2 && !divisible_four) { + return { DeviceType::CPU }; + } + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index a1b38898f9cf05919edf4433a7d502d3ae1626c7..9eb463badd34dd593cac9f190557a5e5e83ed80b 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -186,6 +186,7 @@ TEST_F(ConcatOpTest, QuantizedCPURandom) { builder = builder.Input(MakeString("Input", i)); } builder.AddIntArg("axis", axis_arg) + .AddIntArg("data_format", DataFormat::NHWC) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -245,8 +246,9 @@ TEST_F(ConcatOpTest, QuantizedCPURandom) { namespace { template -void OpenclRandomTest(const std::vector> &shapes, - const int axis) { +void OpenCLRandomTest(const std::vector> &shapes, + const int axis, + DataFormat data_format) { srand(time(nullptr)); int num_inputs = shapes.size(); int concat_axis_size = 0; @@ -262,6 +264,8 @@ void OpenclRandomTest(const std::vector> &shapes, net.AddInputFromArray(input_name, shapes[i], inputs[i]); } + std::vector expected_shape = shapes[0]; + expected_shape[axis] = concat_axis_size; auto builder = OpDefBuilder("Concat", "ConcatTest"); for (int i = 0; i < num_inputs; ++i) { @@ -271,6 +275,8 @@ void OpenclRandomTest(const std::vector> &shapes, builder.AddIntArg("axis", axis) .Output("Output") .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("data_format", data_format) + .OutputShape(expected_shape) .Finalize(net.NewOperatorDef()); // Run @@ -279,8 +285,6 @@ void OpenclRandomTest(const std::vector> &shapes, // Check auto output = net.GetOutput("Output"); - std::vector expected_shape = shapes[0]; - expected_shape[axis] = concat_axis_size; EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape)); Tensor::MappingGuard output_mapper(output); @@ -305,20 +309,38 @@ void OpenclRandomTest(const std::vector> &shapes, } // namespace TEST_F(ConcatOpTest, OPENCLAligned) { - OpenclRandomTest({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3); + OpenCLRandomTest({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3, + DataFormat::NHWC); } TEST_F(ConcatOpTest, OPENCLHalfAligned) { - OpenclRandomTest({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3); + OpenCLRandomTest({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3, + DataFormat::NHWC); } TEST_F(ConcatOpTest, OPENCLUnAligned) { - OpenclRandomTest({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3); + OpenCLRandomTest({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3, + DataFormat::NHWC); } TEST_F(ConcatOpTest, OPENCLAlignedMultiInput) { - OpenclRandomTest( - {{3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}}, 3); + OpenCLRandomTest( + {{3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}}, + 3, DataFormat::NHWC); +} + +TEST_F(ConcatOpTest, GPUFallbackToCPU2DInput) { + OpenCLRandomTest({{3, 4}, {3, 4}}, 1, DataFormat::DF_NONE); +} + +TEST_F(ConcatOpTest, GPUFallbackToCPUChanNotDivisibleBy4) { + OpenCLRandomTest({{1, 1, 4, 3}, {1, 1, 4, 3}}, 3, + DataFormat::DF_NONE); +} + +TEST_F(ConcatOpTest, GPUFallbackToCPUAxis2) { + OpenCLRandomTest({{1, 1, 4, 3}, {1, 1, 4, 3}}, 2, + DataFormat::DF_NONE); } } // namespace test diff --git a/mace/ops/opencl/buffer/buffer_transform.h b/mace/ops/opencl/buffer/buffer_transform.h index 7f9eae2125be87790151a26f404cb4119890ecd2..09038553f1758fd288a034184f8d15519bc6b003 100644 --- a/mace/ops/opencl/buffer/buffer_transform.h +++ b/mace/ops/opencl/buffer/buffer_transform.h @@ -78,7 +78,6 @@ MaceStatus BufferTransform::Compute(OpContext *context, const OpenCLBufferType type, const int wino_blk_size, Tensor *output) { - MACE_UNUSED(type); MACE_UNUSED(wino_blk_size); const DataType dt = DataTypeToEnum::value; switch (type) { @@ -92,8 +91,8 @@ MaceStatus BufferTransform::Compute(OpContext *context, if (input->dtype() != dt) { return BufferTypeTransform(context, &kernel_, input, dt, output); } else { - LOG(FATAL) << "Should not reach here. " << input->name() - << "<" << type << "> to " << output->name(); + SetFutureDefaultWaitFn(context->future()); + output->ReuseTensorBuffer(*input); return MaceStatus::MACE_SUCCESS; } } diff --git a/mace/ops/opencl/buffer_transformer.h b/mace/ops/opencl/buffer_transformer.h index 7acc39a90d7ffb7c89f7d3407402cd27ab19efb6..7279b48f0882527f5f19585ee4de0bd32d2a8b8d 100644 --- a/mace/ops/opencl/buffer_transformer.h +++ b/mace/ops/opencl/buffer_transformer.h @@ -47,6 +47,7 @@ class OpenCLBufferTransformer { const OpenCLBufferType type, const MemoryType out_mem_type, const int wino_blk_size, + const DataFormat data_format, Tensor *output) { Workspace *ws = context->workspace(); DataType dt = DataTypeToEnum::value; @@ -65,7 +66,7 @@ class OpenCLBufferTransformer { VLOG(2) << "Transform CPU Buffer " << input->name() << " to GPU Buffer " << internal_tensor->name() << " with data type " << dt; - if (input->shape().size() == 4) { + if (data_format == DataFormat::NHWC && input->shape().size() == 4) { // 1. (NCHW -> NHWC) std::vector dst_dims = {0, 2, 3, 1}; std::vector output_shape = @@ -103,7 +104,8 @@ class OpenCLBufferTransformer { VLOG(2) << "Transform GPU Buffer " << internal_tensor.name() << " to CPU Buffer " << output->name() << " with data type " << dt; - if (internal_tensor.shape().size() == 4) { + if (data_format == DataFormat::NHWC && + internal_tensor.shape().size() == 4) { // NHWC -> NCHW std::vector dst_dims = {0, 3, 1, 2}; std::vector output_shape = @@ -165,7 +167,7 @@ MaceStatus TransformFilter( input->MarkUnused(); return OpenCLBufferTransformer(input->memory_type(), mem_type). Transform(&op_context, input, buffer_type, mem_type, wino_blk_size, - output); + DataFormat::DF_NONE, output); } } // namespace ops diff --git a/mace/ops/opencl/image/channel_shuffle.h b/mace/ops/opencl/image/channel_shuffle.h index f890c0c3309988cad9acc380560c3358f736e775..8c3a19e0cad726bd2d4a01d02a19680f7b2cc08e 100644 --- a/mace/ops/opencl/image/channel_shuffle.h +++ b/mace/ops/opencl/image/channel_shuffle.h @@ -61,9 +61,6 @@ MaceStatus ChannelShuffleKernel::Compute( const index_t width = input->dim(2); const index_t channels = input->dim(3); const index_t channels_per_group = channels / groups_; - MACE_CHECK(channels_per_group % 4 == 0, - "channels per group must be multiple of 4"); - MACE_CHECK(groups_ % 4 == 0, "groups must be multiple of 4"); const index_t group_channel_blocks = RoundUpDiv4(channels_per_group); const uint32_t gws[3] = {static_cast(group_channel_blocks), diff --git a/mace/ops/opencl/image/concat.h b/mace/ops/opencl/image/concat.h index c7f5e099168f43182cdb9e7bb39ac9df0dbdaeb6..d92f9a7be6e560957af570f1236cc87250fdd60b 100644 --- a/mace/ops/opencl/image/concat.h +++ b/mace/ops/opencl/image/concat.h @@ -67,18 +67,14 @@ MaceStatus ConcatKernel::Compute( const std::vector &input_list, Tensor *output) { const int inputs_count = input_list.size(); - MACE_CHECK(inputs_count >= 2 && axis_ == 3) - << "Concat opencl kernel only support >=2 elements with axis == 3"; const Tensor *input0 = input_list[0]; - bool divisible_four = input0->dim(axis_) % 4 == 0; std::vector output_shape(input0->shape()); for (int i = 1; i < inputs_count; ++i) { const Tensor *input = input_list[i]; MACE_CHECK(input->dim_size() == input0->dim_size(), "Ranks of all input tensors must be same."); - divisible_four &= input->dim(axis_) % 4 == 0; for (int j = 0; j < input->dim_size(); ++j) { if (j == axis_) { continue; @@ -88,9 +84,6 @@ MaceStatus ConcatKernel::Compute( } output_shape[axis_] += input->dim(axis_); } - MACE_CHECK( - inputs_count == 2 || divisible_four, - "Dimensions of inputs should be divisible by 4 when inputs_count > 2."); std::vector image_shape; OpenCLUtil::CalImage2DShape(output_shape, OpenCLBufferType::IN_OUT_CHANNEL, @@ -103,15 +96,9 @@ MaceStatus ConcatKernel::Compute( context, &kernel_, input_list[0], input_list[1], DataTypeToEnum::value, &input_shape_, output, &kwg_size_); default: - if (divisible_four) { - return concat::ConcatN(context, &kernel_, input_list, - DataTypeToEnum::value, output, &kwg_size_); - } else { - MACE_NOT_IMPLEMENTED; - } + return concat::ConcatN(context, &kernel_, input_list, + DataTypeToEnum::value, output, &kwg_size_); } - - return MaceStatus::MACE_SUCCESS; } } // namespace image diff --git a/mace/ops/opencl/image/split.h b/mace/ops/opencl/image/split.h index d0427a4f16ce5b18d37c09ce274e9d1fd621661e..388a55531d8f01c7fd20a5e7ffd48f1718f57555 100644 --- a/mace/ops/opencl/image/split.h +++ b/mace/ops/opencl/image/split.h @@ -34,9 +34,7 @@ namespace image { template class SplitKernel : public OpenCLSplitKernel { public: - explicit SplitKernel(const int32_t axis) : axis_(axis) { - MACE_CHECK(axis == 3) << "GPU only support channel-dimension split"; - } + explicit SplitKernel(const int32_t axis) : axis_(axis) {} MaceStatus Compute( OpContext *context, const Tensor *input, @@ -56,8 +54,6 @@ MaceStatus SplitKernel::Compute( const index_t input_channels = input->dim(3); const size_t outputs_count = output_list.size(); const index_t output_channels = input_channels / outputs_count; - MACE_CHECK(output_channels % 4 == 0) - << "output channels of split op must be divisible by 4"; std::vector output_shape( {input->dim(0), input->dim(1), input->dim(2), output_channels}); diff --git a/mace/ops/ops_test_util.cc b/mace/ops/ops_test_util.cc index 1cb835f60fd3a16e3c8bd12cfae38f91a8c57e99..a0575f7d9ec339c77612525c0405c1ad7ca0ad18 100644 --- a/mace/ops/ops_test_util.cc +++ b/mace/ops/ops_test_util.cc @@ -167,6 +167,9 @@ bool OpsTestNet::Setup(mace::DeviceType device) { !ws_.GetTensor(input)->is_weight()) { auto input_info = net_def.add_input_info(); input_info->set_name(input); + auto data_format = ProtoArgHelper::GetOptionalArg( + op_def_, "data_format", DataFormat::DF_NONE); + input_info->set_data_format(data_format); auto &shape = ws_.GetTensor(input)->shape(); for (auto d : shape) { input_info->add_dims(static_cast(d)); diff --git a/mace/ops/pad.cc b/mace/ops/pad.cc index aa18b7c1c519f5ce2b27967647ddc900199a01f2..ecf2ddff3099572f590209b3829941c98d323bcb 100644 --- a/mace/ops/pad.cc +++ b/mace/ops/pad.cc @@ -35,7 +35,11 @@ class PadOp : public Operation { constant_value_(Operation::GetOptionalArg( "constant_value", 0.0)) { MACE_CHECK(paddings_.size() == 8); - paddings_ = TransposeShape(paddings_, {0, 1, 6, 7, 2, 3, 4, 5}); + auto df = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); + if (df == DataFormat::NHWC) { + paddings_ = TransposeShape(paddings_, {0, 1, 6, 7, 2, 3, 4, 5}); + } } MaceStatus Run(OpContext *context) override { diff --git a/mace/ops/pad_test.cc b/mace/ops/pad_test.cc index 5de799f243e9cc51fb541f6ad5c7601e5de34cc3..c8a4d3494f2b26f4bd98b95bad174d2dc34aed66 100644 --- a/mace/ops/pad_test.cc +++ b/mace/ops/pad_test.cc @@ -34,6 +34,7 @@ void Simple() { .Output("Output") .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddFloatArg("constant_value", 1.0) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run @@ -46,6 +47,7 @@ void Simple() { .Output("TOutput") .AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0}) .AddFloatArg("constant_value", 1.0) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run @@ -84,6 +86,7 @@ TEST_F(PadTest, ComplexCPU) { .Output("TOutput") .AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1}) .AddFloatArg("constant_value", 1.0) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run @@ -120,6 +123,7 @@ void Complex(const std::vector &input_shape, .Output("TOutput") .AddIntsArg("paddings", paddings) .AddFloatArg("constant_value", 1.0) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run @@ -135,6 +139,7 @@ void Complex(const std::vector &input_shape, .Output("Output") .AddIntsArg("paddings", paddings) .AddFloatArg("constant_value", 1.0) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run diff --git a/mace/ops/reduce.cc b/mace/ops/reduce.cc index c14cd48cdad15dd4bdd29a935aa071d5ba84a7e8..bc917c97e1084a2ead9a36f93159db4c7e61fb33 100644 --- a/mace/ops/reduce.cc +++ b/mace/ops/reduce.cc @@ -90,7 +90,9 @@ class ReduceOp : public ReduceOpBase { int index = axis_[i] >= 0 ? axis_[i] : axis_[i] + input->dim_size(); - if (input->dim_size() == 4) { + auto df = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); + if (df == DataFormat::NHWC && input->dim_size() == 4) { if (index == 1 || index == 2) index = index + 1; else if (index == 3) index = 1; } diff --git a/mace/ops/reduce_test.cc b/mace/ops/reduce_test.cc index e9e804e953270f8970fd8987a8a50fe58ea2831a..115c0daf4ab176a3c6d85c2829de3a0660579545 100644 --- a/mace/ops/reduce_test.cc +++ b/mace/ops/reduce_test.cc @@ -44,6 +44,7 @@ void Simple(const std::vector &input_shape, .AddIntsArg("axis", axis) .AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("reduce_type", type) + .AddIntArg("data_format", DataFormat::NHWC) .Output("OutputNCHW") .Finalize(net.NewOperatorDef()); // Run @@ -55,6 +56,7 @@ void Simple(const std::vector &input_shape, .AddIntsArg("axis", axis) .AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("reduce_type", type) + .AddIntArg("data_format", DataFormat::NHWC) .Output("Output") .Finalize(net.NewOperatorDef()); // Run @@ -82,6 +84,7 @@ void Simple3D(const std::vector &input_shape, .AddIntsArg("axis", axis) .AddIntArg("keepdims", keepdims ? 1 : 0) .AddIntArg("reduce_type", type) + .AddIntArg("data_format", DataFormat::NHWC) .Output("Output") .Finalize(net.NewOperatorDef()); // Run @@ -585,6 +588,7 @@ void RandomTest(const std::vector &input_shape, .AddIntsArg("axis", axis) .AddIntArg("keepdims", 1) .AddIntArg("reduce_type", type) + .AddIntArg("data_format", DataFormat::NHWC) .Output("OutputNCHW") .Finalize(net.NewOperatorDef()); // Run @@ -596,6 +600,7 @@ void RandomTest(const std::vector &input_shape, .AddIntsArg("axis", axis) .AddIntArg("keepdims", 1) .AddIntArg("reduce_type", type) + .AddIntArg("data_format", DataFormat::NHWC) .Output("OPENCLOutput") .Finalize(net.NewOperatorDef()); // Run diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 330d3fe1366d0a7cec6b91851551d030641cbee9..8f801657d3a2fdc9c26a4ec03eba8134a136df3f 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -99,6 +99,21 @@ void RegisterReshape(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Reshape") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + if (op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc index 5abb524d6e868eae520f72c299212a5d01cd3afa..1d1bb6a9cf1fd23aa62f29fed735561017f46388 100644 --- a/mace/ops/softmax.cc +++ b/mace/ops/softmax.cc @@ -410,6 +410,22 @@ void RegisterSoftmax(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Softmax", SoftmaxOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Softmax") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + if (op->output_shape(0).dims_size() != 2 && + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/split.cc b/mace/ops/split.cc index d7f33965493cb1e6d0d6124334fe546cc196da86..1e72e01e2351e2eaa0251caef037ee040fc858f8 100644 --- a/mace/ops/split.cc +++ b/mace/ops/split.cc @@ -35,7 +35,9 @@ class SplitOp : public Operation { checked_(false) {} void Validate() { - if (this->Input(0)->dim_size() == 4) { + auto df = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); + if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { if (axis_ == 3) axis_ = 1; else if (axis_ == 2) axis_ = 3; else if (axis_ == 1) axis_ = 2; @@ -139,6 +141,24 @@ void RegisterSplit(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Split", SplitOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Split") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + int axis = ProtoArgHelper::GetOptionalArg( + *op, "axis", 3); + if (axis != 3 || op->output_shape(0).dims_size() != 4 || + (op->output_shape(0).dims()[3] % 4 != 0)) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/split_test.cc b/mace/ops/split_test.cc index 89fbbadefbb39b4a3bc6446f8c6ed58e074636f5..b03f5d76ec93def16092d7e5225ee1d55b2f7b9d 100644 --- a/mace/ops/split_test.cc +++ b/mace/ops/split_test.cc @@ -54,6 +54,7 @@ void RandomTest(const int num_outputs, int axis) { builder = builder.Output(MakeString("Output", i)); } builder.AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("data_format", DataFormat::NHWC) .Finalize(net.NewOperatorDef()); // Run diff --git a/mace/ops/squeeze.cc b/mace/ops/squeeze.cc index bf86a84feb33026047c44951e2acdfbc30467ec2..060854b9757917eb7706d0004a64084377ba4b91 100644 --- a/mace/ops/squeeze.cc +++ b/mace/ops/squeeze.cc @@ -31,11 +31,14 @@ class SqueezeOp : public Operation { MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); if (!checked_ && D == DeviceType::CPU - && DataTypeToEnum::value != DT_UINT8 - && this->Input(0)->dim_size() == 4) { - if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) { - axis_[0] = 2; - axis_[1] = 3; + && DataTypeToEnum::value != DT_UINT8) { + auto df = static_cast(Operation::GetOptionalArg( + "data_format", DataFormat::DF_NONE)); + if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { + if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) { + axis_[0] = 2; + axis_[1] = 3; + } } checked_ = true; } @@ -70,6 +73,21 @@ void RegisterSqueeze(OpRegistryBase *op_registry) { MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::GPU, float); MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::GPU, half); #endif // MACE_ENABLE_OPENCL + MACE_REGISTER_OP_CONDITION( + op_registry, + OpConditionBuilder("Squeeze") + .SetDevicePlacerFunc( + [](OpConstructContext *context) -> std::set { + auto op = context->operator_def(); + if (op->output_shape_size() != op->output_size()) { + return { DeviceType::CPU, DeviceType::GPU }; + } + if (op->output_shape(0).dims_size() != 2 && + op->output_shape(0).dims_size() != 4) { + return { DeviceType::CPU }; + } + return { DeviceType::CPU, DeviceType::GPU }; + })); } } // namespace ops diff --git a/mace/ops/squeeze_test.cc b/mace/ops/squeeze_test.cc index b0fc972cd0479d52bbbd0eff3e96a0bdd7b0a176..512499d54094dfff6977ea22f75b7ccafbcb233d 100644 --- a/mace/ops/squeeze_test.cc +++ b/mace/ops/squeeze_test.cc @@ -30,6 +30,7 @@ void TestSqueeze(const std::vector &org_shape, OpDefBuilder("Squeeze", "SqueezeTest") .Input("Input") .AddIntsArg("axis", axis) + .AddIntArg("data_format", DataFormat::NHWC) .Output("Output") .Finalize(net.NewOperatorDef()); diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 142ec12d33fdf62a8bb989632cf006f9b7105cd3..36cb33d43110203c4a9c35982ae5284658e57944 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -259,6 +259,7 @@ class TransformerRule(Enum): TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36 FOLD_FC_RESHAPE = 37 TRANSFORM_CHANNEL_SHUFFLE = 38 + UPDATE_DATA_FORMAT = 39 class ConverterInterface(object): @@ -479,6 +480,8 @@ class ConverterOption(object): TransformerRule.ADD_OPENCL_INFORMATIONS, # for quantization entropy calibration use TransformerRule.SORT_BY_EXECUTION, + # update the data format of ops + TransformerRule.UPDATE_DATA_FORMAT, # Need to be put after SORT_BY_EXECUTION TransformerRule.ADD_QUANTIZE_TENSOR_RANGE, ] diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index a268fa3267335fa0ade3d638b1df3101a933ebe2..cba32948dacbe9098a43daa9d4d1ed06a2b3ea0a 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -96,6 +96,7 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.ADD_OPENCL_INFORMATIONS: self.add_opencl_informations, TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, + TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format, TransformerRule.CHECK_QUANTIZE_INFO: self.check_quantize_info, TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: @@ -1358,6 +1359,24 @@ class Transformer(base_converter.ConverterInterface): out_shape.dims for out_shape in op.output_shape])) return False + def update_data_format(self): + data_format_flag = DataFormat.NHWC.value + for input_node in self._option.input_nodes.values(): + if input_node.data_format.value == DataFormat.DF_NONE.value: + data_format_flag = DataFormat.DF_NONE.value + + net = self._model + for op in net.op: + data_format_arg = ConverterUtil.get_arg( + op, MaceKeyword.mace_data_format_str) + if not data_format_arg: + data_format_arg = op.arg.add() + data_format_arg.name = MaceKeyword.mace_data_format_str + data_format_arg.i = data_format_flag + elif data_format_arg.i != data_format_flag: + data_format_arg.i = data_format_flag + return False + def quantize_nodes(self): if not self._option.quantize: return False