From b5dfcf4d96c9ea14b9013a659b182505691fb2e0 Mon Sep 17 00:00:00 2001 From: luxuhui Date: Mon, 11 May 2020 20:14:28 +0800 Subject: [PATCH] support RELU6/ArgMax/ResizeNearestNeighbor op for Caffe, fix bug. N/A Signed-off-by: Luxuhui --- mace/core/arg_helper.cc | 4 + mace/core/arg_helper.h | 7 + mace/core/ops/operator.h | 5 + mace/ops/argmax.cc | 181 +++++++++++++----- mace/ops/eltwise.cc | 11 ++ .../opencl/image/resize_nearest_neighbor.cc | 14 +- .../opencl/image/resize_nearest_neighbor.h | 4 +- mace/ops/opencl/resize_nearest_neighbor.h | 4 +- mace/ops/resize_nearest_neighbor.cc | 78 +++++--- test/ccunit/mace/ops/argmax_test.cc | 1 + third_party/caffe/caffe.proto | 13 ++ tools/python/transform/base_converter.py | 5 + tools/python/transform/caffe_converter.py | 104 ++++++++-- tools/python/transform/shape_inference.py | 91 ++++++++- .../python/transform/tensorflow_converter.py | 4 + 15 files changed, 407 insertions(+), 119 deletions(-) diff --git a/mace/core/arg_helper.cc b/mace/core/arg_helper.cc index 0b078bc7..a3f120e8 100644 --- a/mace/core/arg_helper.cc +++ b/mace/core/arg_helper.cc @@ -38,6 +38,10 @@ ProtoArgHelper::ProtoArgHelper(const NetDef &netdef) { } } +bool ProtoArgHelper::ExistArg(const std::string &arg_name) const { + return (arg_map_.count(arg_name) > 0); +} + namespace { template inline bool IsCastLossless(const InputType &value) { diff --git a/mace/core/arg_helper.h b/mace/core/arg_helper.h index 34a0c1f5..90a8ff75 100644 --- a/mace/core/arg_helper.h +++ b/mace/core/arg_helper.h @@ -41,6 +41,11 @@ class ProtoArgHelper { return ProtoArgHelper(def).GetRepeatedArgs(arg_name, default_value); } + template + static bool ExistArg(const Def &def, const std::string &arg_name) { + return ProtoArgHelper(def).ExistArg(arg_name); + } + explicit ProtoArgHelper(const OperatorDef &def); explicit ProtoArgHelper(const NetDef &netdef); @@ -55,6 +60,8 @@ class ProtoArgHelper { template std::vector GetRepeatedArgs(const std::string &arg_name) const; + bool ExistArg(const std::string &arg_name) const; + private: std::map arg_map_; }; diff --git a/mace/core/ops/operator.h b/mace/core/ops/operator.h index bed56720..2a23b4a9 100644 --- a/mace/core/ops/operator.h +++ b/mace/core/ops/operator.h @@ -64,6 +64,11 @@ class Operation { *operator_def_, name); } + bool ExistArg(const std::string &name) const { + MACE_CHECK(operator_def_, "operator_def was null!"); + return ProtoArgHelper::ExistArg(*operator_def_, name); + } + DeviceType device_type() const { return static_cast(operator_def_->device_type()); } diff --git a/mace/ops/argmax.cc b/mace/ops/argmax.cc index 9dccee7d..91a92509 100644 --- a/mace/ops/argmax.cc +++ b/mace/ops/argmax.cc @@ -24,24 +24,101 @@ namespace mace { namespace ops { -template +template class ArgMaxOp : public Operation { public: explicit ArgMaxOp(OpConstructContext *context) : Operation(context), + model_type_(static_cast(Operation::GetOptionalArg( + "framework_type", FrameworkType::TENSORFLOW))), + has_axis_(model_type_ != FrameworkType::CAFFE || + Operation::ExistArg("axis")), + top_k_(Operation::GetOptionalArg("top_k", 1)), + out_val_(Operation::GetOptionalArg("out_val", false)), axis_(Operation::GetOptionalArg("axis", 0)), - keep_dims_(Operation::GetOptionalArg("keepdims", true)), - argmin_(Operation::GetOptionalArg("argmin", false)) {} + argmin_(Operation::GetOptionalArg("argmin", false)), + keep_dims_(Operation::GetOptionalArg("keepdims", true)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(0); - const Tensor *axis = this->InputSize() == 2 ? - this->Input(1) : nullptr; Tensor *output = this->Output(0); - MACE_CHECK(keep_dims_, "Mace only supports keep_dims ArgMax."); - MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar"); + const auto input_dim_size = input->dim_size(); + MACE_CHECK(input_dim_size > 0, "ArgMax input should not be a scalar"); + const auto axis_value = GetAxisValue(input_dim_size); + MACE_RETURN_IF_ERROR(ResizeOutputTensor(output, input, axis_value)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + auto input_data = input->data(); + + int axis_dim = 0; + int axis_dist = 0; + const auto &input_shape = input->shape(); + if (axis_value != 0) { + axis_dim = input->dim(axis_value); + axis_dist = std::accumulate(input_shape.begin() + axis_value, + input_shape.end(), + 1, std::multiplies()) / axis_dim; + } else { + axis_dim = input->dim(0); + axis_dist = 1; + } + const auto output_loop = input->size() / axis_dim; + + for (int i = 0; i < output_loop; i += 1) { + std::vector> input_data_vector(axis_dim); + const auto axis_base = i / axis_dist * axis_dim; + const auto axis_offset = i % axis_dist; + for (int d = 0; d < axis_dim; ++d) { + const auto input_idx = (axis_base + d) * axis_dist + axis_offset; + input_data_vector[d] = std::make_pair(input_data[input_idx], d); + } + + if (argmin_) { + std::partial_sort(input_data_vector.begin(), + input_data_vector.begin() + top_k_, + input_data_vector.end(), + std::less>()); + } else { + std::partial_sort(input_data_vector.begin(), + input_data_vector.begin() + top_k_, + input_data_vector.end(), + std::greater>()); + } + + if (!out_val_) { + auto output_data = output->mutable_data(); + const auto top_k_base = i / axis_dist * top_k_; + for (int j = 0; j < top_k_; ++j) { + const auto output_idx = (top_k_base + j) * axis_dist + axis_offset; + output_data[output_idx] = input_data_vector[j].second; + } + } else if (has_axis_) { // Produces max/min value per axis + auto output_data = output->mutable_data(); + const auto top_k_base = i / axis_dist * top_k_; + for (int j = 0; j < top_k_; ++j) { + auto output_idx = (top_k_base + j) * axis_dist + axis_offset; + output_data[output_idx] = input_data_vector[j].first; + } + } else { // Produces max_ind and max/min value + auto output_data = output->mutable_data(); + const auto top_k_base_pos = 2 * i * top_k_; + const auto top_k_base_value = top_k_base_pos + top_k_; + for (int j = 0; j < top_k_; ++j) { + output_data[top_k_base_pos + j] = input_data_vector[j].second; + output_data[top_k_base_value + j] = input_data_vector[j].first; + } + } + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + int GetAxisValue(const index_t input_dim_size) { + const Tensor *axis = this->InputSize() == 2 ? this->Input(1) : nullptr; int axis_value = 0; if (axis != nullptr) { MACE_CHECK(axis->dim_size() == 0, @@ -52,65 +129,63 @@ class ArgMaxOp : public Operation { axis_value = axis_; } if (axis_value < 0) { - axis_value += input->dim_size(); + axis_value += input_dim_size; } - MACE_CHECK(axis_value == input->dim_size() - 1, - "Mace argmax only supports last dimension as axis"); - std::vector output_shape(input->dim_size() - 1); - for (index_t d = 0; d < input->dim_size() - 1; ++d) { - output_shape[d] = input->dim(d < axis_value ? d : d + 1); - } - MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + return axis_value; + } - Tensor::MappingGuard input_guard(input); - Tensor::MappingGuard output_guard(output); - auto input_data = input->data(); - auto output_data = output->mutable_data(); - - index_t outer_size = output->size(); - index_t inner_size = input->dim(axis_value); - - if (argmin_) { - for (index_t i = 0; i < outer_size; ++i) { - int idx = 0; - float min_value = std::numeric_limits::max(); - const T *input_ptr = input_data + i * inner_size; - for (index_t j = 0; j < inner_size; ++j) { - float input_value = input_ptr[j]; - if (input_value < min_value) { - min_value = input_value; - idx = j; - } - } - output_data[i] = idx; + MaceStatus ResizeOutputTensor(Tensor *output, const Tensor *input, + const index_t axis_value) { + auto &input_shape = input->shape(); + std::vector output_shape; + if (model_type_ == FrameworkType::CAFFE) { + auto output_dim_num = input_shape.size(); + if (output_dim_num < 3) { + output_dim_num = 3; } - } else { - for (index_t i = 0; i < outer_size; ++i) { - int idx = 0; - float max_value = std::numeric_limits::lowest(); - const T *input_ptr = input_data + i * inner_size; - for (index_t j = 0; j < inner_size; ++j) { - float input_value = input_ptr[j]; - if (input_value > max_value) { - max_value = input_value; - idx = j; - } + output_shape.assign(output_dim_num, 1); + if (has_axis_) { + // Produces max/min idx or max/min value per axis + output_shape.assign(input_shape.begin(), input_shape.end()); + output_shape[axis_value] = top_k_; + } else { + output_shape[0] = input_shape[0]; + // Produces max_ind + output_shape[2] = top_k_; + if (out_val_) { + // Produces max/min idx and max/min value + output_shape[1] = 2; } - output_data[i] = idx; + } + } else { // for Tensorflow and ONNX + output_shape.assign(input_shape.begin(), + input_shape.begin() + axis_value); + if (keep_dims_) { + output_shape.push_back(1); + } + for (size_t d = axis_value + 1; d < input_shape.size(); ++d) { + output_shape.push_back(input_shape[d]); } } - return MaceStatus::MACE_SUCCESS; + return output->Resize(output_shape); } protected: - const int axis_; - bool keep_dims_; - bool argmin_; -}; + const FrameworkType model_type_; + // for Caffe + const bool has_axis_; + const bool top_k_; + const bool out_val_; + // for ONNX and TENSORFLOW + const int axis_; + const bool argmin_; + // for ONNX + const bool keep_dims_; +}; void RegisterArgMax(OpRegistry *op_registry) { MACE_REGISTER_OP(op_registry, "ArgMax", ArgMaxOp, DeviceType::CPU, float); diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index af447c94..6dba080a 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -953,6 +953,17 @@ class EltwiseOp : public Operation { swapped = !swapped; } + // convert tensor for caffe's boardcast + if (!has_data_format_ && input0->dim_size() == 4) { + if (input1->dim_size() == 2) { + const_cast(input1)->Reshape( + {input1->dim(0), input1->dim(1), 1, 1}); + } else if (input1->dim_size() == 3) { + const_cast(input1)->Reshape( + {input1->dim(0), input1->dim(1), input1->dim(2), 1}); + } + } + // check if we can broadcast tensor uint32_t rank_diff = static_cast(input0->dim_size() - input1->dim_size()); diff --git a/mace/ops/opencl/image/resize_nearest_neighbor.cc b/mace/ops/opencl/image/resize_nearest_neighbor.cc index 9f9dd1c8..c8206694 100644 --- a/mace/ops/opencl/image/resize_nearest_neighbor.cc +++ b/mace/ops/opencl/image/resize_nearest_neighbor.cc @@ -24,23 +24,13 @@ namespace image { MaceStatus ResizeNearestNeighborKernel::Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, + const index_t out_height, + const index_t out_width, Tensor *output) { const index_t batch = input->dim(0); const index_t in_height = input->dim(1); const index_t in_width = input->dim(2); const index_t channels = input->dim(3); - index_t out_height = 0; - index_t out_width = 0; - if (dims.size() < 2) { - Tensor::MappingGuard size_mapper(size); - out_height = size->data()[0]; - out_width = size->data()[1]; - } else { - out_height = dims[0]; - out_width = dims[1]; - } const index_t channel_blocks = RoundUpDiv4(channels); const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/ops/opencl/image/resize_nearest_neighbor.h b/mace/ops/opencl/image/resize_nearest_neighbor.h index 1092665e..d1463c44 100644 --- a/mace/ops/opencl/image/resize_nearest_neighbor.h +++ b/mace/ops/opencl/image/resize_nearest_neighbor.h @@ -72,8 +72,8 @@ class ResizeNearestNeighborKernel : public OpenCLResizeNearestNeighborKernel { MaceStatus Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, + const index_t out_height, + const index_t out_width, Tensor *output) override; private: diff --git a/mace/ops/opencl/resize_nearest_neighbor.h b/mace/ops/opencl/resize_nearest_neighbor.h index c98fc955..3b6d7514 100644 --- a/mace/ops/opencl/resize_nearest_neighbor.h +++ b/mace/ops/opencl/resize_nearest_neighbor.h @@ -32,8 +32,8 @@ class OpenCLResizeNearestNeighborKernel { virtual MaceStatus Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, + const index_t out_height, + const index_t out_width, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLResizeNearestNeighborKernel); }; diff --git a/mace/ops/resize_nearest_neighbor.cc b/mace/ops/resize_nearest_neighbor.cc index ef51b00a..e9d3370e 100644 --- a/mace/ops/resize_nearest_neighbor.cc +++ b/mace/ops/resize_nearest_neighbor.cc @@ -78,27 +78,37 @@ class ResizeNearestNeighborOp : public Operation { public: explicit ResizeNearestNeighborOp(OpConstructContext *context) : Operation(context), - align_corners_(Operation::GetOptionalArg("align_corners", - false)) {} + align_corners_(Operation::GetOptionalArg("align_corners", false)), + height_scale_(Operation::GetOptionalArg("height_scale", 0)), + width_scale_(Operation::GetOptionalArg("width_scale", 0)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(0); - const Tensor *size = this->Input(1); - Tensor::MappingGuard size_mapper(size); Tensor *output = this->Output(0); - MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1, - "input must be 4-dimensional and size must be 1-dimensional. ", - input->dim_size(), size->dim_size()); + MACE_CHECK(input->dim_size() == 4, + "input must be 4-dimensional.", input->dim_size()); const index_t batch = input->dim(0); const index_t channels = input->dim(1); const index_t in_height = input->dim(2); const index_t in_width = input->dim(3); - const index_t out_height = size->data()[0]; - const index_t out_width = size->data()[1]; + index_t out_height = 0; + index_t out_width = 0; + if (height_scale_ > 0) { // for Caffe + out_height = static_cast(height_scale_ * in_height); + out_width = static_cast(width_scale_ * in_width); + } else { // for tensor (Tf and ONNX) + const Tensor *size = this->Input(1); + Tensor::MappingGuard size_mapper(size); + MACE_CHECK(size->dim_size() == 1, + "size must be 1-dimensional.", size->dim_size()); + out_height = size->data()[0]; + out_width = size->data()[1]; + } + MACE_CHECK(out_height > 0 && out_width > 0, out_height, out_width); std::vector out_shape{batch, channels, out_height, out_width}; MACE_RETURN_IF_ERROR(output->Resize(out_shape)); @@ -114,14 +124,15 @@ class ResizeNearestNeighborOp : public Operation { return MaceStatus::MACE_SUCCESS; } - float height_scale = - common::utils::CalculateResizeScale(in_height, - out_height, - align_corners_); - float width_scale = - common::utils::CalculateResizeScale(in_width, - out_width, - align_corners_); + // Caffe's scale is the opposite of ours + float height_scale = height_scale_ > 0 ? 1 / height_scale_ : + common::utils::CalculateResizeScale(in_height, + out_height, + align_corners_); + float width_scale = width_scale_ > 0 ? 1 / width_scale_ : + common::utils::CalculateResizeScale(in_width, + out_width, + align_corners_); ResizeImageNCHW(context, input_data, batch, @@ -139,6 +150,8 @@ class ResizeNearestNeighborOp : public Operation { private: bool align_corners_; + float height_scale_; + float width_scale_; }; #ifdef MACE_ENABLE_OPENCL @@ -146,7 +159,9 @@ template<> class ResizeNearestNeighborOp : public Operation { public: explicit ResizeNearestNeighborOp(OpConstructContext *context) - : Operation(context), dim_(Operation::GetRepeatedArgs("dim")) { + : Operation(context), dim_(Operation::GetRepeatedArgs("dim")), + height_scale_(Operation::GetOptionalArg("height_scale", 0)), + width_scale_(Operation::GetOptionalArg("width_scale", 0)) { bool align_corners = Operation::GetOptionalArg( "align_corners", false); if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) { @@ -158,17 +173,34 @@ class ResizeNearestNeighborOp : public Operation { } MaceStatus Run(OpContext *context) override { const Tensor *input = this->Input(0); - const Tensor *size = this->Input(1); Tensor *output = this->Output(0); - MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1, - "input must be 4-dimensional and size must be 1-dimensional.", - input->dim_size(), size->dim_size()); + MACE_CHECK(input->dim_size() == 4, + "input must be 4-dimensional.", input->dim_size()); + + index_t out_height = 0; + index_t out_width = 0; + if (height_scale_ > 0) { // for Caffe + out_height = static_cast(height_scale_ * input->dim(1)); + out_width = static_cast(width_scale_ * input->dim(2)); + } else if (dim_.size() < 2) { // for variable tensor (Tf and ONNX) + const Tensor *size = this->Input(1); + Tensor::MappingGuard size_mapper(size); + MACE_CHECK(size->dim_size() == 1, + "size must be 1-dimensional.", size->dim_size()); + out_height = size->data()[0]; + out_width = size->data()[1]; + } else { // for const tensor (Tf and ONNX) + out_height = dim_[0]; + out_width = dim_[1]; + } - return kernel_->Compute(context, input, size, dim_, output); + return kernel_->Compute(context, input, out_height, out_width, output); } private: std::vector dim_; + float height_scale_; + float width_scale_; std::unique_ptr kernel_; }; #endif // MACE_ENABLE_OPENCL diff --git a/test/ccunit/mace/ops/argmax_test.cc b/test/ccunit/mace/ops/argmax_test.cc index 20e5b3f1..4f36f52e 100644 --- a/test/ccunit/mace/ops/argmax_test.cc +++ b/test/ccunit/mace/ops/argmax_test.cc @@ -37,6 +37,7 @@ void ArgMaxTest(const std::vector &input_shape, .Input("Input") .Input("axis") .Output("Output") + .AddIntArg("keepdims", 0) .OutputType({DT_INT32}) .Finalize(net.NewOperatorDef()); // Run diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto index baf38801..d4e03464 100644 --- a/third_party/caffe/caffe.proto +++ b/third_party/caffe/caffe.proto @@ -1833,6 +1833,8 @@ message V1LayerParameter { optional TransformationParameter transform_param = 36; optional LossParameter loss_param = 42; optional V0LayerParameter layer = 1; + optional ResizeNearestParameter resize_nearest_param = 204; + optional GroupNormParameter group_norm_param = 205; } // DEPRECATED: V0LayerParameter is the old way of specifying layer parameters @@ -1946,3 +1948,14 @@ message ShuffleChannelParameter { message L2NormalizationParameter { optional int32 axis = 1 [default = 1]; } + +message GroupNormParameter { + optional float eps = 1 [default = 1e-5]; + optional int32 group_num = 2 [default = 32]; +} + +message ResizeNearestParameter { + optional float height_scale=1 [default = 2.0]; + optional float width_scale =2 [default = 2.0]; + +} diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py index 89a16684..5c81e04a 100644 --- a/tools/python/transform/base_converter.py +++ b/tools/python/transform/base_converter.py @@ -49,6 +49,7 @@ class ActivationType(Enum): TANH = 4 SIGMOID = 5 LEAKYRELU = 6 + RELU6 = 7 class EltwiseType(Enum): @@ -221,6 +222,8 @@ class MaceKeyword(object): mace_batch_to_space_crops_str = 'crops' mace_paddings_str = 'paddings' mace_align_corners_str = 'align_corners' + mace_height_scale_str = 'height_scale' + mace_width_scale_str = 'width_scale' mace_space_batch_block_shape_str = 'block_shape' mace_space_depth_block_size_str = 'block_size' mace_constant_value_str = 'constant_value' @@ -257,6 +260,8 @@ class MaceKeyword(object): mace_epsilon_str = 'epsilon' mace_reduce_type_str = 'reduce_type' mace_argmin_str = 'argmin' + mace_out_val_str = 'out_val' + mace_top_k_str = 'top_k' mace_round_mode_str = 'round_mode' mace_min_size_str = 'min_size' mace_max_size_str = 'max_size' diff --git a/tools/python/transform/caffe_converter.py b/tools/python/transform/caffe_converter.py index a2482a90..35ce6dc6 100644 --- a/tools/python/transform/caffe_converter.py +++ b/tools/python/transform/caffe_converter.py @@ -161,6 +161,7 @@ class CaffeConverter(base_converter.ConverterInterface): } activation_type = { 'ReLU': ActivationType.RELU, + 'ReLU6': ActivationType.RELUX, 'PReLU': ActivationType.PRELU, 'TanH': ActivationType.TANH, 'Sigmoid': ActivationType.SIGMOID, @@ -175,6 +176,7 @@ class CaffeConverter(base_converter.ConverterInterface): 'Eltwise': self.convert_elementwise, 'Add': self.convert_add, 'ReLU': self.convert_activation, + 'ReLU6': self.convert_activation, 'TanH': self.convert_activation, 'Sigmoid': self.convert_activation, 'PReLU': self.convert_activation, @@ -196,7 +198,9 @@ class CaffeConverter(base_converter.ConverterInterface): 'L2Normalization': self.convert_lpnorm, 'L1Normalization': self.convert_lpnorm, 'MVN': self.convert_MVN, - 'Bias': self.convert_Bias, + 'Bias': self.convert_bias, + 'ArgMax': self.convert_argmax, + 'ResizeNearest': self.convert_resize_nearest, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -254,7 +258,7 @@ class CaffeConverter(base_converter.ConverterInterface): for op in ops: for i in six.moves.range(len(op.output)): original_output_name = op.output[i].split('#')[0] - if original_output_name not in visited and\ + if original_output_name not in visited and \ original_output_name not in self._option.input_nodes: self.replace_input_name( consumers.get(op.output[i], []), @@ -456,6 +460,7 @@ class CaffeConverter(base_converter.ConverterInterface): filter_data = caffe_op.blobs[0] self.add_tensor(filter_tensor_name, filter_data.shape, mace_pb2.DT_FLOAT, filter_data) + print("convert conv2d, the filter shape is: ", filter_data.shape) op.input.extend([filter_tensor_name]) if len(caffe_op.blobs) == 2: @@ -499,16 +504,18 @@ class CaffeConverter(base_converter.ConverterInterface): self.add_tensor(alpha_tensor_name, alpha_data.reshape(-1).shape, mace_pb2.DT_FLOAT, alpha_data) op.input.extend([alpha_tensor_name]) - - negative_slope = caffe_op.layer.relu_param.negative_slope - if caffe_op.type == 'ReLU' and negative_slope != 0: - param_arg = op.arg.add() - param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str # noqa - param_arg.f = caffe_op.layer.relu_param.negative_slope - - type_arg.s = six.b(ActivationType.LEAKYRELU.name) - - if caffe_op.type == 'Clip': + elif caffe_op.type == 'ReLU': + negative_slope = caffe_op.layer.relu_param.negative_slope + if negative_slope != 0: + param_arg = op.arg.add() + param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str # noqa + param_arg.f = caffe_op.layer.relu_param.negative_slope + type_arg.s = six.b(ActivationType.LEAKYRELU.name) + elif caffe_op.type == 'ReLU6': + limit_arg = op.arg.add() + limit_arg.name = MaceKeyword.mace_activation_max_limit_str + limit_arg.f = 6.0 + elif caffe_op.type == 'Clip': mace_check(caffe_op.layer.clip_param.min == 0, "Mace only supports min == 0 Clip op") limit_arg = op.arg.add() @@ -668,11 +675,12 @@ class CaffeConverter(base_converter.ConverterInterface): type_arg.name = MaceKeyword.mace_element_type_str type_arg.i = EltwiseType.PROD.value - scale_tensor_name = scale_op_name + '_scale' - scale_data = caffe_op.blobs[0] - self.add_tensor(scale_tensor_name, scale_data.shape, - mace_pb2.DT_FLOAT, scale_data) - op.input.extend([scale_tensor_name]) + if len(caffe_op.blobs) >= 1: + scale_tensor_name = scale_op_name + '_scale' + scale_data = caffe_op.blobs[0] + self.add_tensor(scale_tensor_name, scale_data.shape, + mace_pb2.DT_FLOAT, scale_data) + op.input.extend([scale_tensor_name]) if len(caffe_op.blobs) == 2: bias_tensor_name = scale_op_name + '_offset' @@ -802,8 +810,9 @@ class CaffeConverter(base_converter.ConverterInterface): mace_check(step_w_arg.f > 0, "step_w should be larger than 0.") if param.HasField('step'): - mace_check(not param.HasField('step_h') and not param.HasField('step_w'), # noqa - "Either step or step_h/step_w should be specified; not both.") # noqa + mace_check( + not param.HasField('step_h') and not param.HasField('step_w'), + "Either step or step_h/step_w should be specified; not both.") mace_check(param.step > 0, "step should be larger than 0.") step_h_arg.f = param.step step_w_arg.f = param.step @@ -869,7 +878,7 @@ class CaffeConverter(base_converter.ConverterInterface): eps_arg.name = MaceKeyword.mace_epsilon_str eps_arg.f = param.eps - def convert_Bias(self, caffe_op): + def convert_bias(self, caffe_op): op = self.convert_general_op(caffe_op) op.type = MaceOp.BiasAdd.name param = caffe_op.layer.bias_param @@ -882,3 +891,58 @@ class CaffeConverter(base_converter.ConverterInterface): mace_check(param.axis == 0 or param.axis == 1, "BiasAdd only support axis with 0 or 1.") axis_arg.i = param.axis + if len(caffe_op.blobs) >= 1: + bias_tensor_name = op.name + '_bias' + bias_data = caffe_op.blobs[0] + self.add_tensor(bias_tensor_name, bias_data.shape, + mace_pb2.DT_FLOAT, bias_data) + op.input.extend([bias_tensor_name]) + + def convert_resize_nearest(self, caffe_op): + op = self.convert_general_op(caffe_op) + op.type = MaceOp.ResizeNearestNeighbor.name + + align_corners_arg = op.arg.add() + align_corners_arg.name = MaceKeyword.mace_align_corners_str + align_corners_arg.i = 0 + + height_scale_arg = op.arg.add() + height_scale_arg.name = MaceKeyword.mace_height_scale_str + width_scale_arg = op.arg.add() + width_scale_arg.name = MaceKeyword.mace_width_scale_str + if hasattr(caffe_op, 'layer') and \ + hasattr(caffe_op.layer, 'resize_nearest_param'): + param = caffe_op.layer.resize_nearest_param + height_scale_arg.f = param.height_scale + width_scale_arg.f = param.width_scale + else: + height_scale_arg.f = 2.0 + width_scale_arg.f = 2.0 + + def convert_argmax(self, caffe_op): + op = self.convert_general_op(caffe_op) + op.type = MaceOp.ArgMax.name + + out_max_val = False + if hasattr(caffe_op, 'layer') and \ + hasattr(caffe_op.layer, 'argmax_param'): + param = caffe_op.layer.argmax_param + if hasattr(param, 'out_max_val'): + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_out_val_str + axis_arg.i = param.out_max_val + out_max_val = param.out_max_val + + if hasattr(param, MaceKeyword.mace_top_k_str): + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_top_k_str + axis_arg.i = param.top_k + + if hasattr(param, MaceKeyword.mace_axis_str): + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = param.axis + if out_max_val: + op.output_type.extend([mace_pb2.DT_FLOAT]) + else: + op.output_type.extend([mace_pb2.DT_INT32]) diff --git a/tools/python/transform/shape_inference.py b/tools/python/transform/shape_inference.py index 2a7b43b6..e93f862d 100644 --- a/tools/python/transform/shape_inference.py +++ b/tools/python/transform/shape_inference.py @@ -36,7 +36,7 @@ class ShapeInference(object): MaceOp.Deconv2D.name: self.infer_shape_deconv, MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape, MaceOp.DepthwiseDeconv2d.name: self.infer_shape_deconv, - MaceOp.Eltwise.name: self.infer_shape_general, + MaceOp.Eltwise.name: self.infer_shape_eltwise, MaceOp.BatchNorm.name: self.infer_shape_general, MaceOp.AddN.name: self.infer_shape_general, MaceOp.Activation.name: self.infer_shape_general, @@ -54,6 +54,9 @@ class ShapeInference(object): MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear, MaceOp.LpNorm.name: self.infer_shape_general, MaceOp.MVNorm.name: self.infer_shape_general, + MaceOp.ResizeNearestNeighbor.name: + self.infer_shape_nearest_neighbor, + MaceOp.ArgMax.name: self.infer_shape_argmax, } self._net = net @@ -131,7 +134,7 @@ class ShapeInference(object): output_shape[0] = input_shape[0] if ConverterUtil.data_format(op) == DataFormat.NCHW \ - and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # noqa + and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # filter format: OIHW if op.type == MaceOp.DepthwiseConv2d.name: output_shape[1] = filter_shape[0] * filter_shape[1] @@ -172,7 +175,7 @@ class ShapeInference(object): MaceKeyword.mace_group_str) output_shape[0] = input_shape[0] if ConverterUtil.data_format(op) == DataFormat.NCHW \ - and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # noqa + and ConverterUtil.filter_format(self._net) == DataFormat.OIHW: # filter format: IOHW output_shape[1] = filter_shape[1] if group_arg is not None and group_arg.i > 1: @@ -250,9 +253,12 @@ class ShapeInference(object): input_shape = list(self._output_shape_cache[op.input[0]]) input_w = input_shape[3] input_h = input_shape[2] - min_size = ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats # noqa - max_size = ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats # noqa - aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa + min_size = \ + ConverterUtil.get_arg(op, MaceKeyword.mace_min_size_str).floats + max_size = \ + ConverterUtil.get_arg(op, MaceKeyword.mace_max_size_str).floats + aspect_ratio = \ + ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats num_prior = len(aspect_ratio) * len(min_size) + len(max_size) output_shape[2] = int(num_prior * input_h * input_w * 4) @@ -282,7 +288,8 @@ class ShapeInference(object): else: output_shape = [] axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i - end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i # noqa + end_axis = ConverterUtil.get_arg(op, + MaceKeyword.mace_end_axis_str).i end_axis = end_axis if end_axis > 0 else end_axis + len( list(self._output_shape_cache[op.input[0]])) dim = 1 @@ -310,3 +317,73 @@ class ShapeInference(object): mace_check(False, "format %s is not supported" % ConverterUtil.data_format(op)) self.add_output_shape(op, [output_shape]) + + def infer_shape_nearest_neighbor(self, op): + input_shape = self._output_shape_cache[op.input[0]] + height_scale = \ + ConverterUtil.get_arg(op, MaceKeyword.mace_height_scale_str).f + width_scale = \ + ConverterUtil.get_arg(op, MaceKeyword.mace_width_scale_str).f + if ConverterUtil.data_format(op) == DataFormat.NCHW: + output_shape = [input_shape[0], input_shape[1], + int(input_shape[2] * height_scale), + int(input_shape[3] * width_scale)] + elif ConverterUtil.data_format(op) == DataFormat.NHWC: + output_shape = [input_shape[0], int(input_shape[2] * height_scale), + int(input_shape[3] * width_scale), input_shape[3]] + else: + output_shape = [] + mace_check(False, "format %s is not supported" + % ConverterUtil.data_format(op)) + self.add_output_shape(op, [output_shape]) + + def infer_shape_argmax(self, op): + input_shape = self._output_shape_cache[op.input[0]] + output_dim_num = len(input_shape) + if output_dim_num < 3: + output_dim_num = 3 + + axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str) + has_axis = (axis_arg is not None) + axis_value = 0 + if has_axis: + axis_value = axis_arg.i + if axis_value < 0: + axis_value = len(input_shape) + axis_value + + top_k = ConverterUtil.get_arg(op, MaceKeyword.mace_top_k_str).i + mace_check(top_k >= 1, "Invalid top_k value") + out_val = ConverterUtil.get_arg(op, MaceKeyword.mace_out_val_str).i + + if has_axis: # Produces max_ind or max_val per axis + output_shape = input_shape + output_shape[axis_value] = top_k + else: + output_shape = [1] * output_dim_num + output_shape[0] = input_shape[0] + output_shape[2] = top_k + if out_val: # Produces max_ind and max_val + output_shape[1] = 2 + + self.add_output_shape(op, [output_shape]) + + def infer_shape_eltwise(self, op): + input_num = len(op.input) + mace_check(input_num > 0, "input num should > 0") + + max_idx = 0 + max_input_size = 0 + for i in range(0, input_num): + mace_check(op.input[i] in self._output_shape_cache, + "Op %s input %s does not exist" + % (op.name, op.input[i])) + input_shape = self._output_shape_cache[op.input[i]] + input_size = 1 + for k in range(0, len(input_shape)): + input_size *= input_shape[k] + if input_size > max_input_size: + max_idx = i + max_input_size = input_size + + input_max_shape = self._output_shape_cache[op.input[max_idx]] + self.add_output_shape(op, [input_max_shape]) diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index c57a83fc..dc85cfeb 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -1046,6 +1046,10 @@ class TensorflowConverter(base_converter.ConverterInterface): op.type = MaceOp.ArgMax.name op.output_type.extend([mace_pb2.DT_INT32]) + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str + keep_dims_arg.i = 0 + def convert_split(self, tf_op): op = self.convert_general_op(tf_op) num_or_size_splits = tf_op.get_attr('num_split') -- GitLab