提交 2b1defc7 编写于 作者: 李滨

Merge branch 'mnmt' into 'master'

Implement gather op

See merge request !555
......@@ -89,6 +89,7 @@ extern void Register_Dequantize(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Gather(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
......@@ -130,6 +131,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_GATHER_H_
#define MACE_KERNELS_GATHER_H_
#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
// Configuration shared by the GatherFunctor specializations.
struct GatherBase {
  // Stores both arguments verbatim; no validation happens here.
  //   axis: dimension of `params` to gather along (a negative value is
  //         offset by the rank of `params` inside the functor).
  //   y:    factor the gathered output is multiplied by.
  explicit GatherBase(int axis, float y) : axis_{axis}, y_{y} {}

  int axis_;  // gather axis as configured
  float y_;   // output scale factor
};
// Gather kernel, specialized per device and element type.
template <DeviceType D, typename T>
struct GatherFunctor;

// CPU / float implementation of Gather.
//
// Copies slices of `params` selected by `indices` along the configured axis
// and multiplies the result by `y_`.
template <>
struct GatherFunctor<DeviceType::CPU, float> : GatherBase {
  explicit GatherFunctor(int axis, float y) : GatherBase(axis, y) {}

  // Computes the gather.
  //   params:  float tensor to read slices from.
  //   indices: tensor read as int32 positions along the gather axis.
  //   output:  resized to
  //            params.shape[:axis] + indices.shape + params.shape[axis + 1:]
  //            and filled with the selected slices, scaled by y_.
  //   future:  unused.
  // Returns the error from output->Resize() if resizing fails, otherwise
  // MACE_SUCCESS. An axis outside [-rank, rank) trips MACE_CHECK.
  MaceStatus operator()(const Tensor *params,
                        const Tensor *indices,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);

    // Normalize a negative axis into a local so that the configured axis_
    // is not rewritten as a side effect of running the functor.
    int axis = axis_;
    if (axis < 0) {
      axis += params->dim_size();
    }
    MACE_CHECK(axis >= 0 && axis < params->dim_size(),
               "axis is out of bound: ", axis);

    const auto &params_shape = params->shape();
    const auto &indices_shape = indices->shape();

    std::vector<index_t> output_shape;
    output_shape.insert(output_shape.end(), params_shape.begin(),
                        params_shape.begin() + axis);
    output_shape.insert(output_shape.end(), indices_shape.begin(),
                        indices_shape.end());
    output_shape.insert(output_shape.end(),
                        params_shape.begin() + (axis + 1),
                        params_shape.end());
    MACE_RETURN_IF_ERROR(output->Resize(output_shape));

    Tensor::MappingGuard indices_guard(indices);
    Tensor::MappingGuard params_guard(params);
    Tensor::MappingGuard output_guard(output);
    const int32_t *indices_data = indices->data<int32_t>();
    const float *params_data = params->data<float>();
    float *output_data = output->mutable_data<float>();

    const index_t axis_dim_size = params->dim(axis);
    // The initial value fixes std::accumulate's accumulator type; seed with
    // index_t so the products are not computed (and truncated) in int.
    const index_t lhs_size =
        std::accumulate(params_shape.begin(), params_shape.begin() + axis,
                        static_cast<index_t>(1), std::multiplies<index_t>());
    const index_t rhs_size =
        std::accumulate(params_shape.begin() + (axis + 1), params_shape.end(),
                        static_cast<index_t>(1), std::multiplies<index_t>());
    const index_t index_size = indices->size();

    // Each (l, idx) pair copies one contiguous run of rhs_size floats.
#pragma omp parallel for collapse(2)
    for (index_t l = 0; l < lhs_size; ++l) {
      for (index_t idx = 0; idx < index_size; ++idx) {
        // Reject negative indices as well: they would read before the
        // start of the params buffer.
        MACE_ASSERT(indices_data[idx] >= 0 &&
                        indices_data[idx] < axis_dim_size,
                    "idx out of bound: ", indices_data[idx]);
        memcpy(
            output_data + ((l * index_size) + idx) * rhs_size,
            params_data + ((l * axis_dim_size) + indices_data[idx]) * rhs_size,
            sizeof(float) * rhs_size);
      }
    }

    // Skip the scaling pass when y_ is (numerically) 1.
    if (std::fabs(y_ - 1.0) > 1e-6) {
      const index_t output_size = output->size();
#pragma omp parallel for
      for (index_t i = 0; i < output_size; ++i) {
        output_data[i] *= y_;
      }
    }
    return MACE_SUCCESS;
  }
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_GATHER_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/gather.h"
namespace mace {
namespace ops {
// Registers the Gather operator with `op_registry`.
// The key is built for op name "Gather" on DeviceType::CPU with the type
// constraint "T" = float, and is associated with
// GatherOp<DeviceType::CPU, float>.
void Register_Gather(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Gather")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
GatherOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_GATHER_H_
#define MACE_OPS_GATHER_H_
#include "mace/core/operator.h"
#include "mace/kernels/gather.h"
namespace mace {
namespace ops {
// Operator wrapper around kernels::GatherFunctor.
//
// Optional arguments read at construction time:
//   "axis" (int,   default 0)   - forwarded to the functor as its axis.
//   "y"    (float, default 1.0) - forwarded to the functor as its scale.
template<DeviceType D, class T>
class GatherOp : public Operator<D, T> {
 public:
  GatherOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetOptionalArg<int>("axis", 0),
                 OperatorBase::GetOptionalArg<float>("y", 1.0)) {}

  // Fetches the PARAMS and INDICES inputs and the OUTPUT tensor, then
  // delegates to the functor and returns its status.
  MaceStatus Run(StatsFuture *future) override {
    const Tensor *params_tensor = this->Input(PARAMS);
    const Tensor *indices_tensor = this->Input(INDICES);
    Tensor *output_tensor = this->Output(OUTPUT);
    return functor_(params_tensor, indices_tensor, output_tensor, future);
  }

 protected:
  MACE_OP_INPUT_TAGS(PARAMS, INDICES);
  MACE_OP_OUTPUT_TAGS(OUTPUT);

 private:
  kernels::GatherFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_GATHER_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/gather.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
// Benchmarks the Gather op along axis 0.
//   iters:         number of timed runs.
//   n, index_len:  the Indices tensor has shape {n, index_len}.
//   vocab_len, embedding_len:
//                  the Params tensor has shape {vocab_len, embedding_len};
//                  every generated index lies in [0, vocab_len).
// Timing is stopped during setup and two warm-up runs, and started just
// before the measured loop.
template <DeviceType D, typename T>
void GatherBenchmark(int iters,
                     index_t n,
                     index_t index_len,
                     index_t vocab_len,
                     index_t embedding_len) {
  mace::testing::StopTiming();
  static unsigned int seed = time(nullptr);

  OpsTestNet net;

  // The Indices tensor is declared as {n, index_len}, so it needs
  // n * index_len values (not just index_len) to match its shape.
  std::vector<int32_t> index(n * index_len);
  for (auto &value : index) {
    value = rand_r(&seed) % vocab_len;
  }
  net.AddInputFromArray<D, int32_t>("Indices", {n, index_len}, index);
  net.AddRandomInput<D, T>("Params", {vocab_len, embedding_len});

  OpDefBuilder("Gather", "GatherTest")
      .Input("Params")
      .Input("Indices")
      .AddIntArg("axis", 0)
      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Warm-up
  for (int i = 0; i < 2; ++i) {
    net.RunOp(D);
    net.Sync();
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
    net.Sync();
  }
}
} // namespace
// Defines a static benchmark function named
// MACE_BM_GATHER_<N>_<IND>_<VOC>_<EMBED>_<TYPE>_<DEVICE> and passes it to
// MACE_BENCHMARK. The function reports 0 MACCs and
// iters * N * IND * EMBED * sizeof(TYPE) processed bytes, then runs
// GatherBenchmark<DEVICE, TYPE>(iters, N, IND, VOC, EMBED).
#define MACE_BM_GATHER_MACRO(N, IND, VOC, EMBED, TYPE, DEVICE) \
static void \
MACE_BM_GATHER##_##N##_##IND##_##VOC##_##EMBED##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * IND * EMBED; \
mace::testing::MaccProcessed(0); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
GatherBenchmark<DEVICE, TYPE>(iters, N, IND, VOC, EMBED); \
} \
MACE_BENCHMARK( \
MACE_BM_GATHER##_##N##_##IND##_##VOC##_##EMBED##_##TYPE##_##DEVICE)
// Convenience wrapper that fixes TYPE = float and DEVICE = CPU.
// NOTE(review): the expansion already ends with ';', so the call sites
// below produce an extra empty declaration.
#define MACE_BM_GATHER(N, INDEX, VOCAB, EMBEDDING) \
MACE_BM_GATHER_MACRO(N, INDEX, VOCAB, EMBEDDING, float, CPU);
// Benchmark cases: N = 1, index lengths 7 / 20 / 100, a 48165 x 256 Params
// table.
MACE_BM_GATHER(1, 7, 48165, 256);
MACE_BM_GATHER(1, 20, 48165, 256);
MACE_BM_GATHER(1, 100, 48165, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the Gather operator tests; adds nothing to OpsTestBase.
class GatherOpTest : public OpsTestBase {};
namespace {
void TestGather(const std::vector<index_t> &weight_shape,
const std::vector<float> &weight,
const std::vector<index_t> &input_shape,
const std::vector<int32_t> &input,
const int axis,
const float y,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Params", weight_shape, weight);
net.AddInputFromArray<CPU, int32_t>("Indices", input_shape, input);
OpDefBuilder("Gather", "GatherTest")
.Input("Params")
.Input("Indices")
.AddIntArg("axis", axis)
.AddFloatArg("y", y)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(CPU);
auto expected = CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
// Scalar index 5 along axis 0 selects row {10, 11} of the 10x2 table;
// y = 2 doubles it to {20, 22}.
TEST_F(GatherOpTest, CPUScalarIndex) {
  const std::vector<float> params = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                     10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
  const std::vector<int32_t> indices = {5};
  const std::vector<float> expected = {20, 22};
  TestGather({10, 2}, params, {}, indices, 0, 2.0, {2}, expected);
}
// Rank-1 indices {2, 4, 6} along axis 0 select rows 2, 4 and 6 of the
// 10x2 table; y = 1 leaves the values unscaled.
TEST_F(GatherOpTest, CPURank1Index) {
  const std::vector<float> params = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                     10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
  const std::vector<int32_t> indices = {2, 4, 6};
  const std::vector<float> expected = {4, 5, 8, 9, 12, 13};
  TestGather({10, 2}, params, {3}, indices, 0, 1.0, {3, 2}, expected);
}
// Index {1} along axis 1 selects the second column of the 10x2 table,
// giving a 10x1 output of the odd values.
TEST_F(GatherOpTest, CPURank1IndexWithAxis1) {
  const std::vector<float> params = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                     10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
  const std::vector<int32_t> indices = {1};
  const std::vector<float> expected = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
  TestGather({10, 2}, params, {1}, indices, 1, 1.0, {10, 1}, expected);
}
// Rank-2 indices of shape {1, 3} along axis 0: the output keeps the index
// shape and appends the row length, giving {1, 3, 2}.
TEST_F(GatherOpTest, CPURankHighIndex) {
  const std::vector<float> params = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                     10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
  const std::vector<int32_t> indices = {2, 4, 6};
  const std::vector<float> expected = {4, 5, 8, 9, 12, 13};
  TestGather({10, 2}, params, {1, 3}, indices, 0, 1.0, {1, 3, 2}, expected);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册