From 92f18fc6842ef79fd9cc785f8fdea174afca2ad9 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Thu, 16 Aug 2018 16:28:40 +0800 Subject: [PATCH] support tf basic lstm on cpu --- mace/kernels/fill.h | 72 +++++++++++++++++++ mace/kernels/opencl/cl/{slice.cl => split.cl} | 2 +- mace/kernels/opencl/{slice.cc => split.cc} | 16 ++--- mace/kernels/{slice.h => split.h} | 18 ++--- mace/kernels/strided_slice.h | 20 +++++- mace/ops/fill.cc | 29 ++++++++ mace/ops/fill.h | 49 +++++++++++++ mace/ops/fill_test.cc | 65 +++++++++++++++++ mace/ops/ops_register.cc | 6 +- mace/ops/{slice.cc => split.cc} | 16 ++--- mace/ops/{slice.h => split.h} | 18 ++--- ...{slice_benchmark.cc => split_benchmark.cc} | 32 ++++----- mace/ops/{slice_test.cc => split_test.cc} | 16 ++--- mace/ops/strided_slice_test.cc | 12 ++++ .../tools/converter_tool/base_converter.py | 3 + .../converter_tool/tensorflow_converter.py | 20 ++++-- .../tools/converter_tool/transformer.py | 1 + .../opencl-kernel/opencl_kernel_configure.bzl | 2 +- 18 files changed, 328 insertions(+), 69 deletions(-) create mode 100644 mace/kernels/fill.h rename mace/kernels/opencl/cl/{slice.cl => split.cl} (95%) rename mace/kernels/opencl/{slice.cc => split.cc} (90%) rename mace/kernels/{slice.h => split.h} (89%) create mode 100644 mace/ops/fill.cc create mode 100644 mace/ops/fill.h create mode 100644 mace/ops/fill_test.cc rename mace/ops/{slice.cc => split.cc} (74%) rename mace/ops/{slice.h => split.h} (78%) rename mace/ops/{slice_benchmark.cc => split_benchmark.cc} (78%) rename mace/ops/{slice_test.cc => split_test.cc} (93%) diff --git a/mace/kernels/fill.h b/mace/kernels/fill.h new file mode 100644 index 00000000..5e172c3f --- /dev/null +++ b/mace/kernels/fill.h @@ -0,0 +1,72 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_FILL_H_ +#define MACE_KERNELS_FILL_H_ + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +struct FillBase { + explicit FillBase(float value) : value_(value) {} + + int value_; +}; + +template +struct FillFunctor; + +template <> +struct FillFunctor : FillBase { + explicit FillFunctor(float value) : FillBase(value) {} + + MaceStatus operator()(const Tensor *shape, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + MACE_CHECK(shape->dim_size() == 1) << "Shape must be 1-D"; + const index_t num_dims = shape->dim(0); + Tensor::MappingGuard shape_guard(shape); + const int32_t *shape_data = shape->data(); + + std::vector output_shape; + for (index_t i = 0; i < num_dims; ++i) { + MACE_CHECK(shape_data[i] > 0) << "Shape must be non-negative: " + << shape_data[i]; + output_shape.push_back(shape_data[i]); + } + + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + Tensor::MappingGuard output_guard(output); + float *output_data = output->mutable_data(); + + std::fill(output_data, output_data + output->size(), value_); + + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_FILL_H_ diff --git a/mace/kernels/opencl/cl/slice.cl b/mace/kernels/opencl/cl/split.cl similarity index 95% rename from mace/kernels/opencl/cl/slice.cl rename to mace/kernels/opencl/cl/split.cl index f6b0c35a..8f93742e 100644 --- a/mace/kernels/opencl/cl/slice.cl +++ b/mace/kernels/opencl/cl/split.cl @@ -1,6 +1,6 @@ #include -__kernel void slice(KERNEL_ERROR_PARAMS +__kernel void split(KERNEL_ERROR_PARAMS GLOBAL_WORK_GROUP_SIZE_DIM3 __read_only image2d_t input, __private const int chan_blk_offset, diff --git a/mace/kernels/opencl/slice.cc b/mace/kernels/opencl/split.cc similarity index 90% rename from mace/kernels/opencl/slice.cc rename to mace/kernels/opencl/split.cc index b778e0d7..65fd6be5 100644 --- a/mace/kernels/opencl/slice.cc +++ b/mace/kernels/opencl/split.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mace/kernels/slice.h" +#include "mace/kernels/split.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/kernels/opencl/helper.h" #include "mace/utils/tuner.h" @@ -21,7 +21,7 @@ namespace mace { namespace kernels { template -MaceStatus SliceFunctor::operator()( +MaceStatus SplitFunctor::operator()( const Tensor *input, const std::vector &output_list, StatsFuture *future) { @@ -29,7 +29,7 @@ MaceStatus SliceFunctor::operator()( const size_t outputs_count = output_list.size(); const index_t output_channels = input_channels / outputs_count; MACE_CHECK(output_channels % 4 == 0) - << "output channels of slice op must be divisible by 4"; + << "output channels of split op must be divisible by 4"; std::vector output_shape( {input->dim(0), input->dim(1), input->dim(2), output_channels}); @@ -46,12 +46,12 @@ MaceStatus SliceFunctor::operator()( std::set built_options; OUT_OF_RANGE_CONFIG(kernel_error_); NON_UNIFORM_WG_CONFIG; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice"); - built_options.emplace("-Dslice=" + kernel_name); + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("split"); + built_options.emplace("-Dsplit=" + kernel_name); built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum::value)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum::value)); - MACE_RETURN_IF_ERROR(runtime->BuildKernel("slice", + MACE_RETURN_IF_ERROR(runtime->BuildKernel("split", kernel_name, built_options, &kernel_)); @@ -116,8 +116,8 @@ MaceStatus SliceFunctor::operator()( return MACE_SUCCESS; } -template struct SliceFunctor; -template struct SliceFunctor; +template struct SplitFunctor; +template struct SplitFunctor; } // namespace kernels } // namespace mace diff --git a/mace/kernels/slice.h b/mace/kernels/split.h similarity index 89% rename from mace/kernels/slice.h rename to mace/kernels/split.h index 7ab311b0..95ff7861 100644 --- a/mace/kernels/slice.h +++ b/mace/kernels/split.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MACE_KERNELS_SLICE_H_ -#define MACE_KERNELS_SLICE_H_ +#ifndef MACE_KERNELS_SPLIT_H_ +#define MACE_KERNELS_SPLIT_H_ #include #include @@ -31,15 +31,15 @@ namespace mace { namespace kernels { -struct SliceFunctorBase { - explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {} +struct SplitFunctorBase { + explicit SplitFunctorBase(const int32_t axis) : axis_(axis) {} int32_t axis_; }; template -struct SliceFunctor : SliceFunctorBase { - explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} +struct SplitFunctor : SplitFunctorBase { + explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {} MaceStatus operator()(const Tensor *input, const std::vector &output_list, @@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase { #ifdef MACE_ENABLE_OPENCL template -struct SliceFunctor : SliceFunctorBase { - explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} +struct SplitFunctor : SplitFunctorBase { + explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {} MaceStatus operator()(const Tensor *input, const std::vector &output_list, @@ -104,4 +104,4 @@ struct SliceFunctor : SliceFunctorBase { } // namespace kernels } // namespace mace -#endif // MACE_KERNELS_SLICE_H_ +#endif // MACE_KERNELS_SPLIT_H_ diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h index eab4a4d5..e966367f 100644 --- a/mace/kernels/strided_slice.h +++ b/mace/kernels/strided_slice.h @@ -169,7 +169,6 @@ struct StridedSliceFunctor { i += strides_data[0]) { *output_data++ = input_data[i]; } - } else if (input->dim_size() == 2) { for (index_t i = real_begin_indices[0]; strides_data[0] > 0 ? i < real_end_indices[0] @@ -179,7 +178,24 @@ struct StridedSliceFunctor { strides_data[1] > 0 ? j < real_end_indices[1] : j > real_end_indices[1]; j += strides_data[1]) { - *output_data++ = input_data[i * dim_stride[0] + j]; + *output_data++ = input_data[i * input->dim(1) + j]; + } + } + } else if (input->dim_size() == 3) { + for (index_t i = real_begin_indices[0]; + strides_data[0] > 0 ? i < real_end_indices[0] + : i > real_end_indices[0]; + i += strides_data[0]) { + for (index_t j = real_begin_indices[1]; + strides_data[1] > 0 ? j < real_end_indices[1] + : j > real_end_indices[1]; + j += strides_data[1]) { + for (index_t k = real_begin_indices[2]; + strides_data[2] > 0 ? k < real_end_indices[2] + : k > real_end_indices[2]; + k += strides_data[2]) { + *output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k]; + } } } } else { diff --git a/mace/ops/fill.cc b/mace/ops/fill.cc new file mode 100644 index 00000000..93e6dadd --- /dev/null +++ b/mace/ops/fill.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/fill.h" + +namespace mace { +namespace ops { + +void Register_Fill(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + FillOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/fill.h b/mace/ops/fill.h new file mode 100644 index 00000000..3e2c6df7 --- /dev/null +++ b/mace/ops/fill.h @@ -0,0 +1,49 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_FILL_H_ +#define MACE_OPS_FILL_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/fill.h" + +namespace mace { +namespace ops { + +template +class FillOp : public Operator { + public: + FillOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetOptionalArg("value", 0.0f)) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *shape = this->Input(SHAPE); + Tensor *output = this->Output(OUTPUT); + return functor_(shape, output, future); + } + + private: + kernels::FillFunctor functor_; + + MACE_OP_INPUT_TAGS(SHAPE); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_FILL_H_ diff --git a/mace/ops/fill_test.cc b/mace/ops/fill_test.cc new file mode 100644 index 00000000..bc3a3363 --- /dev/null +++ b/mace/ops/fill_test.cc @@ -0,0 +1,65 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class FillTest : public OpsTestBase {}; + +namespace { +void TestFill(const std::vector &shape, + const float &value) { + // Construct graph + OpsTestNet net; + OpDefBuilder("Fill", "FillTest") + .Input("Shape") + .AddFloatArg("value", static_cast(value)) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray( + "Shape", + {static_cast(shape.size())}, + shape); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + + for (index_t i = 0; i < output->dim_size(); ++i) { + ASSERT_EQ(output->dim(i), shape[i]); + } + + const float *output_ptr = output->data(); + const index_t size = output->size(); + for (index_t i = 0; i < size; ++i) { + ASSERT_EQ(output_ptr[i], value); + } +} +} // namespace + +TEST_F(FillTest, Simple) { + TestFill({3, 2, 1}, 5.0f); + TestFill({1, 3}, -1.0f); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index 886546e3..3afe66c9 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry); extern void Register_Dequantize(OperatorRegistryBase *op_registry); extern void Register_Eltwise(OperatorRegistryBase *op_registry); +extern void Register_Fill(OperatorRegistryBase *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry); extern void Register_FullyConnected(OperatorRegistryBase *op_registry); extern void Register_Gather(OperatorRegistryBase *op_registry); @@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry); extern void Register_Reshape(OperatorRegistryBase *op_registry); extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry); -extern void Register_Slice(OperatorRegistryBase *op_registry); +extern void Register_Split(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry); extern void Register_Stack(OperatorRegistryBase *op_registry); extern void Register_StridedSlice(OperatorRegistryBase *op_registry); @@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_DepthwiseConv2d(this); ops::Register_Dequantize(this); ops::Register_Eltwise(this); + ops::Register_Fill(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); ops::Register_Gather(this); @@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); ops::Register_Shape(this); - ops::Register_Slice(this); + ops::Register_Split(this); ops::Register_Softmax(this); ops::Register_Stack(this); ops::Register_StridedSlice(this); diff --git a/mace/ops/slice.cc b/mace/ops/split.cc similarity index 74% rename from mace/ops/slice.cc rename to mace/ops/split.cc index b6bf4b24..e5e103d7 100644 --- a/mace/ops/slice.cc +++ b/mace/ops/split.cc @@ -12,30 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mace/ops/slice.h" +#include "mace/ops/split.h" namespace mace { namespace ops { -void Register_Slice(OperatorRegistryBase *op_registry) { - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") +void Register_Split(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::CPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); #ifdef MACE_ENABLE_OPENCL - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::GPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); - MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split") .Device(DeviceType::GPU) .TypeConstraint("T") .Build(), - SliceOp); + SplitOp); #endif // MACE_ENABLE_OPENCL } diff --git a/mace/ops/slice.h b/mace/ops/split.h similarity index 78% rename from mace/ops/slice.h rename to mace/ops/split.h index 7f01162f..710cdfb3 100644 --- a/mace/ops/slice.h +++ b/mace/ops/split.h @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MACE_OPS_SLICE_H_ -#define MACE_OPS_SLICE_H_ +#ifndef MACE_OPS_SPLIT_H_ +#define MACE_OPS_SPLIT_H_ #include #include "mace/core/operator.h" -#include "mace/kernels/slice.h" +#include "mace/kernels/split.h" namespace mace { namespace ops { template -class SliceOp : public Operator { +class SplitOp : public Operator { public: - SliceOp(const OperatorDef &op_def, Workspace *ws) + SplitOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), functor_(OperatorBase::GetOptionalArg("axis", 3)) {} @@ -35,15 +35,15 @@ class SliceOp : public Operator { << "There must be at least two outputs for slicing"; const Tensor *input = this->Input(INPUT); const std::vector output_list = this->Outputs(); - const int32_t slice_axis = OperatorBase::GetOptionalArg("axis", 3); - MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0) + const int32_t split_axis = OperatorBase::GetOptionalArg("axis", 3); + MACE_CHECK((input->dim(split_axis) % this->OutputSize()) == 0) << "Outputs do not split input equally."; return functor_(input, output_list, future); } private: - kernels::SliceFunctor functor_; + kernels::SplitFunctor functor_; private: MACE_OP_INPUT_TAGS(INPUT); @@ -52,4 +52,4 @@ class SliceOp : public Operator { } // namespace ops } // namespace mace -#endif // MACE_OPS_SLICE_H_ +#endif // MACE_OPS_SPLIT_H_ diff --git a/mace/ops/slice_benchmark.cc b/mace/ops/split_benchmark.cc similarity index 78% rename from mace/ops/slice_benchmark.cc rename to mace/ops/split_benchmark.cc index c02dbf5c..8dea1263 100644 --- a/mace/ops/slice_benchmark.cc +++ b/mace/ops/split_benchmark.cc @@ -22,7 +22,7 @@ namespace test { namespace { template -void BMSliceHelper(int iters, +void BMSplitHelper(int iters, const std::vector &input_shape, const index_t num_outputs) { mace::testing::StopTiming(); @@ -42,7 +42,7 @@ void BMSliceHelper(int iters, BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("InputImage"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("OutputImage", i)); @@ -51,7 +51,7 @@ void BMSliceHelper(int iters, .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); } else { - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("Input"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("Output", i)); @@ -73,28 +73,28 @@ void BMSliceHelper(int iters, } } // namespace -#define MACE_BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE) \ +#define MACE_BM_SPLIT_MACRO(N, H, W, C, NO, TYPE, DEVICE) \ static void \ - MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \ + MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * H * W * C; \ mace::testing::MaccProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - BMSliceHelper(iters, {N, H, W, C}, NO); \ + BMSplitHelper(iters, {N, H, W, C}, NO); \ } \ MACE_BENCHMARK( \ - MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE) + MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE) -#define MACE_BM_SLICE(N, H, W, C, NO) \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, CPU); \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, GPU); \ - MACE_BM_SLICE_MACRO(N, H, W, C, NO, half, GPU); +#define MACE_BM_SPLIT(N, H, W, C, NO) \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, CPU); \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, GPU); \ + MACE_BM_SPLIT_MACRO(N, H, W, C, NO, half, GPU); -MACE_BM_SLICE(1, 32, 32, 32, 2); -MACE_BM_SLICE(1, 32, 32, 128, 2); -MACE_BM_SLICE(1, 32, 32, 256, 2); -MACE_BM_SLICE(1, 128, 128, 32, 2); -MACE_BM_SLICE(1, 128, 128, 128, 2); +MACE_BM_SPLIT(1, 32, 32, 32, 2); +MACE_BM_SPLIT(1, 32, 32, 128, 2); +MACE_BM_SPLIT(1, 32, 32, 256, 2); +MACE_BM_SPLIT(1, 128, 128, 32, 2); +MACE_BM_SPLIT(1, 128, 128, 128, 2); } // namespace test } // namespace ops diff --git a/mace/ops/slice_test.cc b/mace/ops/split_test.cc similarity index 93% rename from mace/ops/slice_test.cc rename to mace/ops/split_test.cc index b445d56a..57544d18 100644 --- a/mace/ops/slice_test.cc +++ b/mace/ops/split_test.cc @@ -17,13 +17,13 @@ #include "gmock/gmock.h" #include "mace/ops/ops_test_util.h" -#include "mace/ops/slice.h" +#include "mace/ops/split.h" namespace mace { namespace ops { namespace test { -class SliceOpTest : public OpsTestBase {}; +class SplitOpTest : public OpsTestBase {}; namespace { template @@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) { BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); - auto builder = OpDefBuilder("Slice", "SliceTest"); + auto builder = OpDefBuilder("Split", "SplitTest"); builder.Input("InputImage"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("OutputImage", i)); @@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) { builder.AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); } else { - auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis); + auto builder = OpDefBuilder("Split", "SplitTest").AddIntArg("axis", axis); builder.Input("Input"); for (int i = 0; i < num_outputs; ++i) { builder = builder.Output(MakeString("Output", i)); @@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) { } } // namespace -TEST_F(SliceOpTest, CPU) { +TEST_F(SplitOpTest, CPU) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); } -TEST_F(SliceOpTest, CPUAxis1) { +TEST_F(SplitOpTest, CPUAxis1) { RandomTest(2, 1); RandomTest(4, 1); RandomTest(11, 1); } -TEST_F(SliceOpTest, OPENCLFloat) { +TEST_F(SplitOpTest, OPENCLFloat) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); } -TEST_F(SliceOpTest, OPENCLHalf) { +TEST_F(SplitOpTest, OPENCLHalf) { RandomTest(2, 3); RandomTest(4, 3); RandomTest(11, 3); diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index 322f1135..d975d7be 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) { 0, 3, {}, {6}); } +TEST_F(StridedSliceOpTest, TestStridedSliceRank3) { + TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {0, 0, 0}, {2, 3, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, 2, 2}, + {1, 2, 5, 6, 7, 8, 11, 12}); + TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, + 6, 6}, {1, 0, 0}, {2, 1, 3}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, + 1, 3}, {3, 3, 3}); + TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, + 6, 6}, {0, 0, 0}, {2, 2, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, + 1, 2}, {1, 1, 3, 3}); +} + TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, {1, 2, 3, 4, 5, 6}); diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 33ef662c..9a5440f4 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -88,6 +88,7 @@ MaceSupportedOps = [ 'Dequantize', 'Eltwise', 'FoldedBatchNorm', + 'Fill', 'FullyConnected', 'Gather', 'Identity', @@ -101,6 +102,7 @@ MaceSupportedOps = [ 'Reshape', 'ResizeBilinear', 'Slice', + 'Split', 'Shape', 'Squeeze', 'Stack', @@ -146,6 +148,7 @@ class MaceKeyword(object): mace_constant_value_str = 'constant_value' mace_dims_str = 'dims' mace_axis_str = 'axis' + mace_num_split_str = 'num_split' mace_keepdims_str = 'keepdims' mace_shape_str = 'shape' mace_winograd_filter_transformed = 'is_filter_transformed' diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index fca6ca95..be4678ed 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -68,6 +68,7 @@ TFSupportedOps = [ 'Relu6', 'Tanh', 'Sigmoid', + 'Fill', 'FusedBatchNorm', 'AvgPool', 'MaxPool', @@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Relu6.name: self.convert_activation, TFOpType.Tanh.name: self.convert_activation, TFOpType.Sigmoid.name: self.convert_activation, + TFOpType.Fill.name: self.convert_fill, TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm, TFOpType.AvgPool.name: self.convert_pooling, TFOpType.MaxPool.name: self.convert_pooling, @@ -458,6 +460,14 @@ class TensorflowConverter(base_converter.ConverterInterface): limit_arg.name = MaceKeyword.mace_activation_max_limit_str limit_arg.f = 6.0 + def convert_fill(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Fill.name + + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_value_str + value_arg.f = tf_op.inputs[1].eval() + def convert_fused_batchnorm(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.FoldedBatchNorm.name @@ -763,19 +773,19 @@ class TensorflowConverter(base_converter.ConverterInterface): op.output_type.extend([mace_pb2.DT_INT32]) def convert_split(self, tf_op): - # inputs: [dim, input] axis = tf_op.inputs[0].eval().astype(np.int32) axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis - mace_check(axis == 3, 'Split with %d axis only support' % axis) input_shape = self.infer_tensor_shape(tf_op.inputs[1]) - mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0), - "The input's 4th dimension should be a multiple of 4") op = self.convert_general_op(tf_op) - op.type = MaceOp.Slice.name + op.type = MaceOp.Split.name del op.input[0] axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str axis_arg.i = axis + num_split_arg = op.arg.add() + num_split_arg.name = MaceKeyword.mace_num_split_str + num_split_arg.i = tf_op.get_attr('num_split') + self._skip_tensor.add(tf_op.inputs[0].name) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 1e68cb14..16d9eae0 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface): "only support concat at " "channel dimension") arg.i = 3 + producer = self._producer[op.input[0]] input_shape = producer.output_shape[0].dims if producer.type == MaceOp.FullyConnected.name and \ diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl index 0da8838d..0d1b9cf0 100644 --- a/repository/opencl-kernel/opencl_kernel_configure.bzl +++ b/repository/opencl-kernel/opencl_kernel_configure.bzl @@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl")) - unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/slice.cl")) + unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl")) -- GitLab