未验证 提交 c57097ff 编写于 作者: Y yiicy 提交者: GitHub

add where_index op, test=develop (#3789)

* [ARM] add where_index op, test=develop

* [ARM] add where_index op, test=develop
上级 4e9852e7
...@@ -16,3 +16,8 @@ add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${li ...@@ -16,3 +16,8 @@ add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${li
add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps}) add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps}) add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/where_index_compute.h"
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
static void where_index_rank4(const int64_t* true_index,
int true_num,
const int64_t* stride,
int64_t* out) {
int cnt = true_num >> 1;
register int64_t stride0 = stride[0];
register int64_t stride1 = stride[1];
register int64_t stride2 = stride[2];
register int64_t stride3 = stride[3];
for (int i = 0; i < cnt; ++i) {
int64_t index0 = true_index[i * 2];
int64_t index1 = true_index[i * 2 + 1];
int out_index = i * 8;
// rank0
register int64_t oindex0 = index0 / stride0;
register int64_t oindex1 = index1 / stride0;
out[out_index] = oindex0;
index0 -= oindex0 * stride0;
index1 -= oindex1 * stride0;
out[out_index + 4] = oindex1;
out_index++;
// rank1
oindex0 = index0 / stride1;
oindex1 = index1 / stride1;
out[out_index] = oindex0;
index0 -= oindex0 * stride1;
index1 -= oindex1 * stride1;
out[out_index + 4] = oindex1;
out_index++;
// rank2
oindex0 = index0 / stride2;
oindex1 = index1 / stride2;
out[out_index] = oindex0;
index0 -= oindex0 * stride2;
index1 -= oindex1 * stride2;
out[out_index + 4] = oindex1;
out_index++;
// rank3
oindex0 = index0 / stride3;
oindex1 = index1 / stride3;
out[out_index] = oindex0;
out[out_index + 4] = oindex1;
}
// remain
for (int r = cnt * 2; r < true_num; ++r) {
int out_index = r * 4;
int64_t index = true_index[r];
for (int i = 0; i < 4; ++i) {
out[out_index + i] = index / stride[i];
index -= out[out_index + i] * stride[i];
}
}
}
inline void where_index_rank1(const int64_t* true_index,
int true_num,
int64_t* out) {
memcpy(out, true_index, true_num * sizeof(int64_t));
}
static void where_index_rankn(const int64_t* true_index,
int true_num,
const int64_t* stride,
int rank,
int64_t* out) {
int out_index = 0;
for (int i = 0; i < true_num; ++i) {
int64_t index = true_index[i];
for (int r = 0; r < rank; ++r) {
out[out_index] = index / stride[r];
index -= out[out_index++] * stride[r];
}
}
}
template <typename T>
void WhereIndexKernel(const operators::WhereIndexParam& param) {
auto* input = param.input;
auto* output = param.output;
auto dims = input->dims();
auto numel = dims.production();
int64_t rank = static_cast<int64_t>(dims.size());
const T* cond_data = input->template data<T>();
int64_t true_num = 0;
std::vector<int64_t> true_index(numel);
for (auto i = 0; i < numel; i++) {
if (static_cast<bool>(cond_data[i])) {
true_index[true_num] = i;
true_num++;
}
}
output->Resize({true_num, rank});
if (true_num == 0) {
return;
}
auto* out_ptr = output->template mutable_data<int64_t>();
std::vector<int64_t> stride(rank);
stride[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * dims[i + 1];
}
if (rank == 1) {
where_index_rank1(true_index.data(), true_num, out_ptr);
} else if (rank == 4) {
where_index_rank4(true_index.data(), true_num, stride.data(), out_ptr);
} else {
where_index_rankn(
true_index.data(), true_num, stride.data(), rank, out_ptr);
}
}
void WhereIndexCompute::Run() {
auto& param = this->Param<operators::WhereIndexParam>();
switch (param.input->precision()) {
case PRECISION(kFloat):
WhereIndexKernel<float>(param);
break;
case PRECISION(kInt32):
WhereIndexKernel<int32_t>(param);
break;
case PRECISION(kInt64):
WhereIndexKernel<int64_t>(param);
break;
case PRECISION(kInt8):
WhereIndexKernel<int8_t>(param);
break;
case PRECISION(kBool):
WhereIndexKernel<bool>(param);
break;
default:
LOG(FATAL) << "WhereIndex does not implement for the "
<< "input type:" << static_cast<int>(param.input->precision());
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
using whereindex = paddle::lite::kernels::host::WhereIndexCompute;
REGISTER_LITE_KERNEL(where_index, kHost, kAny, kAny, whereindex, def)
.BindInput("Condition",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/where_index_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class WhereIndexCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
public:
using param_t = operators::WhereIndexParam;
void Run() override;
virtual ~WhereIndexCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/where_index_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T>
void where_index_compute_ref(lite::Tensor* condition, lite::Tensor* out) {
auto dims = condition->dims();
auto numel = condition->numel();
const int64_t rank = static_cast<int64_t>(dims.size());
const T* cond_data = condition->data<T>();
std::vector<int64_t> true_index;
for (auto i = 0; i < numel; i++) {
if (static_cast<bool>(cond_data[i])) {
true_index.push_back(i);
}
}
int64_t true_num = static_cast<int64_t>(true_index.size());
out->Resize({true_num, rank});
int64_t* out_ptr = out->mutable_data<int64_t>();
if (true_num == 0) {
return;
}
std::vector<int64_t> stride(rank);
stride[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * dims[i + 1];
}
for (int i = 0; i < true_num; ++i) {
int64_t index = true_index[i];
for (int j = 0; j < rank; ++j) {
out_ptr[i * rank + j] = index / stride[j];
index -= out_ptr[i * rank + j] * stride[j];
}
}
}
TEST(where_index, init) {
WhereIndexCompute where_index;
ASSERT_EQ(where_index.precision(), PRECISION(kAny));
ASSERT_EQ(where_index.target(), TARGET(kHost));
}
TEST(where_index, retrive_op) {
auto where_index =
KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kAny)>(
"where_index");
ASSERT_FALSE(where_index.empty());
ASSERT_TRUE(where_index.front());
}
TEST(where_index, compute) {
paddle::lite::DeviceInfo::Init();
WhereIndexCompute where_index;
operators::WhereIndexParam param;
lite::Tensor input;
lite::Tensor output;
lite::Tensor output_ref;
param.input = &input;
param.output = &output;
where_index.SetParam(param);
for (auto& n : {1, 2, 4}) {
for (auto& c : {1, 3, 21, 32}) {
for (auto& h : {1, 5, 63}) {
for (auto& w : {1, 5, 64}) {
for (auto& dim_size : {1, 2, 3, 4}) {
for (int i = 0; i < 5; ++i) {
std::vector<int64_t> in_shape;
in_shape.push_back(n);
in_shape.push_back(c);
in_shape.push_back(h);
in_shape.push_back(w);
int outer = 1;
for (int i = dim_size - 1; i < in_shape.size(); ++i) {
outer *= in_shape[i];
}
in_shape.resize(dim_size);
in_shape[dim_size - 1] = outer;
DDim indim(in_shape);
LOG(INFO) << "in dims: ";
for (int i = 0; i < dim_size; ++i) {
LOG(INFO) << in_shape[i];
}
input.Resize(indim);
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-1, 1);
if (i == 0) {
int* indata = input.mutable_data<int32_t>();
for (int i = 0; i < indim.production(); ++i) {
indata[i] = static_cast<int>(dist(engine) > 0);
}
where_index_compute_ref<int32_t>(&input, &output_ref);
} else if (i == 1) {
int64_t* indata = input.mutable_data<int64_t>();
for (int i = 0; i < indim.production(); ++i) {
indata[i] = static_cast<int64_t>(dist(engine) > 0);
}
where_index_compute_ref<int64_t>(&input, &output_ref);
} else if (i == 2) {
int8_t* indata = input.mutable_data<int8_t>();
for (int i = 0; i < indim.production(); ++i) {
indata[i] = static_cast<int8_t>(dist(engine) > 0);
}
where_index_compute_ref<int8_t>(&input, &output_ref);
} else if (i == 3) {
bool* indata = input.mutable_data<bool>();
for (int i = 0; i < indim.production(); ++i) {
indata[i] = dist(engine) > 0;
}
where_index_compute_ref<bool>(&input, &output_ref);
} else {
float* indata = input.mutable_data<float>();
for (int i = 0; i < indim.production(); ++i) {
indata[i] = dist(engine) > 0;
}
where_index_compute_ref<float>(&input, &output_ref);
}
where_index.Run();
const int64_t* outdata = output.data<int64_t>();
const int64_t* outdata_ref = output_ref.data<int64_t>();
CHECK_EQ(output.dims(), output_ref.dims())
<< "where_index out shape error! out_dim is not equal "
"to out_ref dim";
for (int i = 0; i < output.numel(); i++) {
if (std::abs(outdata[i] - outdata_ref[i]) > 0) {
LOG(FATAL) << "where_index cmp error, i: " << i
<< ", output_data: " << outdata[i]
<< ", output_ref_data: " << outdata_ref[i]
<< "input precision: "
<< static_cast<int>(input.precision());
}
}
}
}
}
}
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(where_index, kHost, kAny, kAny, def);
...@@ -137,6 +137,7 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS}) ...@@ -137,6 +137,7 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS}) add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS}) add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS}) add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
# for content-dnn specific # for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS}) add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS}) add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
......
...@@ -1568,6 +1568,11 @@ struct PixelShuffleParam : ParamBase { ...@@ -1568,6 +1568,11 @@ struct PixelShuffleParam : ParamBase {
lite::Tensor* output{nullptr}; lite::Tensor* output{nullptr};
int upscale_factor{1}; int upscale_factor{1};
}; };
struct WhereIndexParam : ParamBase {
const lite::Tensor* input{nullptr};
lite::Tensor* output{nullptr};
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/where_index_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool WhereIndexdOp::CheckShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.output);
CHECK_GE(param_.input->dims().size(), 1);
return true;
}
bool WhereIndexdOp::InferShapeImpl() const {
int64_t rank = static_cast<int64_t>(param_.input->dims().size());
int64_t numel = static_cast<int64_t>(param_.input->dims().production());
param_.output->Resize({numel, rank});
return true;
}
bool WhereIndexdOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto input = opdesc.Input("Condition").front();
auto output = opdesc.Output("Out").front();
CHECK(scope->FindVar(input));
CHECK(scope->FindVar(output));
param_.input = GetVar<lite::Tensor>(scope, input);
param_.output = GetMutableVar<lite::Tensor>(scope, output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(where_index, paddle::lite::operators::WhereIndexdOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class WhereIndexdOp : public OpLite {
public:
WhereIndexdOp() {}
explicit WhereIndexdOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "where_index_op"; }
private:
mutable WhereIndexParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册