From 156ffc75685734f4d26b63c2510aad84ea28ba8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Fri, 30 Mar 2018 18:32:34 +0800 Subject: [PATCH] Support inception_v3 neon --- mace/kernels/arm/conv_2d.cc | 2 +- mace/ops/concat.cc | 6 +++++- mace/python/tools/tf_converter_lib.py | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mace/kernels/arm/conv_2d.cc b/mace/kernels/arm/conv_2d.cc index ccd17e9a..e50cac08 100644 --- a/mace/kernels/arm/conv_2d.cc +++ b/mace/kernels/arm/conv_2d.cc @@ -159,7 +159,6 @@ void Conv2dFunctor::operator()(const Tensor *input, auto filter_data = filter->data(); auto bias_data = bias == nullptr ? nullptr : bias->data(); auto output_data = output->mutable_data(); - memset(output_data, 0, sizeof(float) * batch * channels * height * width); if (USE_WINOGRAD && filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1 @@ -301,6 +300,7 @@ void Conv2dFunctor::operator()(const Tensor *input, std::vector extra_output_shape {batch, channels, extra_output_height, extra_output_width}; padded_output_.Resize(extra_output_shape); + padded_output_.Clear(); pad_output_ptr = &padded_output_; } float *pad_output_data = pad_output_ptr->mutable_data(); diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc index 2e6dbc8f..30d3cec0 100644 --- a/mace/ops/concat.cc +++ b/mace/ops/concat.cc @@ -13,7 +13,6 @@ void Register_Concat(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), ConcatOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat") .Device(DeviceType::OPENCL) .TypeConstraint("T") @@ -24,6 +23,11 @@ void Register_Concat(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), ConcatOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + ConcatOp); } } // namespace ops diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 61f3e926..22732a18 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -663,7 +663,10 @@ class TFConverter(object): op_def.output.extend([output.name for output in op.outputs]) axis_arg = op_def.arg.add() axis_arg.name = 'axis' - axis_arg.i = get_input_tensor(op, len(op.inputs) - 1).eval().astype(np.int32) + axis = get_input_tensor(op, len(op.inputs) - 1).eval().astype(np.int32) + if self.device == 'neon' and axis == 3: + axis = 1 + axis_arg.i = axis self.add_output_shape(op.outputs, op_def) self.resolved_ops[op.name] = 1 self.unused_tensor.add(get_input_tensor(op, len(op.inputs) - 1).name) -- GitLab