From 300466c76231769e270be08e18a6b8e2e32f8f51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Mon, 4 Jun 2018 19:23:34 +0800 Subject: [PATCH] Implement gather op --- mace/core/operator.cc | 2 + mace/kernels/gather.h | 105 +++++++++++++++++++++++++++++++++++ mace/ops/gather.cc | 29 ++++++++++ mace/ops/gather.h | 51 +++++++++++++++++ mace/ops/gather_benchmark.cc | 87 +++++++++++++++++++++++++++++ mace/ops/gather_test.cc | 82 +++++++++++++++++++++++++++ 6 files changed, 356 insertions(+) create mode 100644 mace/kernels/gather.h create mode 100644 mace/ops/gather.cc create mode 100644 mace/ops/gather.h create mode 100644 mace/ops/gather_benchmark.cc create mode 100644 mace/ops/gather_test.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 25504d5b..e5355b2d 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -89,6 +89,7 @@ extern void Register_Dequantize(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); extern void Register_FullyConnected(OperatorRegistry *op_registry); +extern void Register_Gather(OperatorRegistry *op_registry); extern void Register_LocalResponseNorm(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_Pad(OperatorRegistry *op_registry); @@ -130,6 +131,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Eltwise(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); + ops::Register_Gather(this); ops::Register_LocalResponseNorm(this); ops::Register_MatMul(this); ops::Register_Pad(this); diff --git a/mace/kernels/gather.h b/mace/kernels/gather.h new file mode 100644 index 00000000..101a60e3 --- /dev/null +++ b/mace/kernels/gather.h @@ -0,0 +1,105 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_GATHER_H_ +#define MACE_KERNELS_GATHER_H_ + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +struct GatherBase { + explicit GatherBase(int axis, float y) : axis_(axis), y_(y) {} + + int axis_; + float y_; +}; + +template +struct GatherFunctor; + +template <> +struct GatherFunctor : GatherBase { + explicit GatherFunctor(int axis, float y) : GatherBase(axis, y) {} + + MaceStatus operator()(const Tensor *params, + const Tensor *indices, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + std::vector output_shape; + if (axis_ < 0) { + axis_ += params->dim_size(); + } + MACE_CHECK(axis_ >= 0 && axis_ < params->dim_size(), + "axis is out of bound: ", axis_); + output_shape.insert(output_shape.end(), params->shape().begin(), + params->shape().begin() + axis_); + output_shape.insert(output_shape.end(), indices->shape().begin(), + indices->shape().end()); + output_shape.insert(output_shape.end(), + params->shape().begin() + (axis_ + 1), + params->shape().end()); + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard indices_guard(indices); + Tensor::MappingGuard params_guard(params); + Tensor::MappingGuard output_guard(output); + const int32_t *indices_data = indices->data(); + const float *params_data = params->data(); + float *output_data = output->mutable_data(); + + index_t axis_dim_size = params->dim(axis_); + index_t lhs_size = std::accumulate(params->shape().begin(), + params->shape().begin() + axis_, 1, + std::multiplies()); + index_t rhs_size = + std::accumulate(params->shape().begin() + (axis_ + 1), + params->shape().end(), 1, std::multiplies()); + index_t index_size = indices->size(); + +#pragma omp parallel for collapse(2) + for (index_t l = 0; l < lhs_size; ++l) { + for (index_t idx = 0; idx < index_size; ++idx) { + MACE_ASSERT(indices_data[idx] < axis_dim_size, "idx out of bound: ", + indices_data[idx]); + memcpy( + output_data + ((l * index_size) + idx) * rhs_size, + params_data + ((l * axis_dim_size) + indices_data[idx]) * rhs_size, + sizeof(float) * rhs_size); + } + } + + if (std::fabs(y_ - 1.0) > 1e-6) { +#pragma omp parallel for + for (index_t i = 0; i < output->size(); ++i) { + output_data[i] *= y_; + } + } + + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_GATHER_H_ diff --git a/mace/ops/gather.cc b/mace/ops/gather.cc new file mode 100644 index 00000000..bc9687cf --- /dev/null +++ b/mace/ops/gather.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/gather.h" + +namespace mace { +namespace ops { + +void Register_Gather(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Gather") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + GatherOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/gather.h b/mace/ops/gather.h new file mode 100644 index 00000000..37689b30 --- /dev/null +++ b/mace/ops/gather.h @@ -0,0 +1,51 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_GATHER_H_ +#define MACE_OPS_GATHER_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/gather.h" + +namespace mace { +namespace ops { + +template +class GatherOp : public Operator { + public: + GatherOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetOptionalArg("axis", 0), + OperatorBase::GetOptionalArg("y", 1.0)) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *params = this->Input(PARAMS); + const Tensor *indices = this->Input(INDICES); + Tensor *output = this->Output(OUTPUT); + + return functor_(params, indices, output, future); + } + + private: + kernels::GatherFunctor functor_; + + protected: + MACE_OP_INPUT_TAGS(PARAMS, INDICES); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_GATHER_H_ diff --git a/mace/ops/gather_benchmark.cc b/mace/ops/gather_benchmark.cc new file mode 100644 index 00000000..f55b7462 --- /dev/null +++ b/mace/ops/gather_benchmark.cc @@ -0,0 +1,87 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/kernels/gather.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void GatherBenchmark(int iters, + index_t n, + index_t index_len, + index_t vocab_len, + index_t embedding_len) { + mace::testing::StopTiming(); + static unsigned int seed = time(NULL); + + OpsTestNet net; + std::vector index(index_len); + for (int i = 0; i < index_len; ++i) { + index[i] = rand_r(&seed) % vocab_len; + } + net.AddInputFromArray("Indices", {n, index_len}, index); + net.AddRandomInput("Params", {vocab_len, embedding_len}); + + OpDefBuilder("Gather", "GatherTest") + .Input("Params") + .Input("Indices") + .AddIntArg("axis", 0) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 2; ++i) { + net.RunOp(D); + net.Sync(); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_GATHER_MACRO(N, IND, VOC, EMBED, TYPE, DEVICE) \ + static void \ + MACE_BM_GATHER##_##N##_##IND##_##VOC##_##EMBED##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * IND * EMBED; \ + mace::testing::MaccProcessed(0); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + GatherBenchmark(iters, N, IND, VOC, EMBED); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_GATHER##_##N##_##IND##_##VOC##_##EMBED##_##TYPE##_##DEVICE) + +#define MACE_BM_GATHER(N, INDEX, VOCAB, EMBEDDING) \ + MACE_BM_GATHER_MACRO(N, INDEX, VOCAB, EMBEDDING, float, CPU); + +MACE_BM_GATHER(1, 7, 48165, 256); +MACE_BM_GATHER(1, 20, 48165, 256); +MACE_BM_GATHER(1, 100, 48165, 256); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/gather_test.cc b/mace/ops/gather_test.cc new file mode 100644 index 00000000..3a35b338 --- /dev/null +++ b/mace/ops/gather_test.cc @@ -0,0 +1,82 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class GatherOpTest : public OpsTestBase {}; + +namespace { +void TestGather(const std::vector &weight_shape, + const std::vector &weight, + const std::vector &input_shape, + const std::vector &input, + const int axis, + const float y, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + + net.AddInputFromArray("Params", weight_shape, weight); + net.AddInputFromArray("Indices", input_shape, input); + + OpDefBuilder("Gather", "GatherTest") + .Input("Params") + .Input("Indices") + .AddIntArg("axis", axis) + .AddFloatArg("y", y) + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(CPU); + + auto expected = CreateTensor(output_shape, output); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} +} // namespace + +TEST_F(GatherOpTest, CPUScalarIndex) { + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {}, {5}, 0, 2.0, {2}, {20, 22}); +} + +TEST_F(GatherOpTest, CPURank1Index) { + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {3}, {2, 4, 6}, 0, 1.0, {3, 2}, {4, 5, 8, 9, 12, 13}); +} + +TEST_F(GatherOpTest, CPURank1IndexWithAxis1) { + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {1}, {1}, 1, 1.0, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19}); +} + +TEST_F(GatherOpTest, CPURankHighIndex) { + TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, + {1, 3}, {2, 4, 6}, 0, 1.0, {1, 3, 2}, {4, 5, 8, 9, 12, 13}); +} + +} // namespace test +} // namespace ops +} // namespace mace -- GitLab