Unverified commit 678a7c85, authored by HappyAngel, committed by GitHub

[arm] add scatter op on arm. test=develop (#4290)

* add scatter op on arm. test=develop

* fix cmake error. test=develop

* test=develop
Parent d353b126
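For context: a scatter op writes rows of Updates into rows of X selected by Ids, either replacing those rows (overwrite = true) or accumulating into them (overwrite = false). The snippet below is only a minimal, self-contained sketch of that idea to make the diff easier to follow; the function and variable names are illustrative and not part of this patch, and the authoritative reference for this PR is scatter_basic in the test file at the end of the diff.

#include <cstdint>
#include <vector>

// Illustrative sketch (not part of the patch): rows of x (shape [n, row])
// are updated at positions ids[i] with rows of updates; overwrite replaces
// the row, otherwise the update is accumulated into it.
std::vector<float> scatter_rows(const std::vector<float>& x,
                                const std::vector<int64_t>& ids,
                                const std::vector<float>& updates,
                                int row, bool overwrite) {
  std::vector<float> out = x;  // start from the input tensor
  for (size_t i = 0; i < ids.size(); ++i) {
    for (int j = 0; j < row; ++j) {
      const float v = updates[i * row + j];
      float& o = out[ids[i] * row + j];
      o = overwrite ? v : o + v;
    }
  }
  return out;
}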
@@ -130,5 +130,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
lstm.cc
clip.cc
pixel_shuffle.cc
scatter.cc
DEPS ${lite_kernel_deps} context tensor)
endif()
@@ -54,6 +54,7 @@
#include "lite/backends/arm/math/reduce_mean.h" #include "lite/backends/arm/math/reduce_mean.h"
#include "lite/backends/arm/math/reduce_prod.h" #include "lite/backends/arm/math/reduce_prod.h"
#include "lite/backends/arm/math/scale.h" #include "lite/backends/arm/math/scale.h"
#include "lite/backends/arm/math/scatter.h"
#include "lite/backends/arm/math/sequence_expand.h" #include "lite/backends/arm/math/sequence_expand.h"
#include "lite/backends/arm/math/sequence_pool.h" #include "lite/backends/arm/math/sequence_pool.h"
#include "lite/backends/arm/math/sequence_pool_grad.h" #include "lite/backends/arm/math/sequence_pool_grad.h"
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/arm/math/scatter.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void scatter<float>(const int64_t* indexs,
const float* src,
float* dst,
int index_size,
int num,
int size,
bool overwrite) {
for (int i = 0; i < num; i++) {
const float* din = src + indexs[i] * size;
memcpy(dst, din, sizeof(float) * size);
dst += size;
}
if (overwrite) {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
memcpy(dout, din, sizeof(float) * size);
}
} else {
int cnt = size >> 3;
int rem = size & 7;
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
for (int j = 0; j < cnt; j++) {
float32x4_t va0 = vld1q_f32(din);
float32x4_t vb0 = vld1q_f32(dout);
float32x4_t va1 = vld1q_f32(din + 4);
float32x4_t vb1 = vld1q_f32(dout + 4);
vb0 = vaddq_f32(va0, vb0);
vb1 = vaddq_f32(va1, vb1);
din += 8;
vst1q_f32(dout, vb0);
vst1q_f32(dout + 4, vb1);
dout += 8;
}
for (int j = 0; j < rem; j++) {
dout[0] += *din++;
dout++;
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
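A note on the non-overwrite branch above: the NEON block consumes eight floats per iteration (two float32x4_t registers loaded from din and dout, added, and stored back to dout), and the remaining size & 7 elements are handled by the scalar tail loop. For readers unfamiliar with the intrinsics, a scalar equivalent of one vectorized iteration would be the following sketch (illustrative only, not part of the patch):

// Same effect as one vld1q_f32 / vaddq_f32 / vst1q_f32 iteration above:
// element-wise add eight floats from din into dout.
static void accumulate8(const float* din, float* dout) {
  for (int k = 0; k < 8; ++k) {
    dout[k] += din[k];
  }
}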
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void scatter(const int64_t* indexs,
const T* updates,
T* dst,
int index_size,
int num,
int size,
bool overwrite);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
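The parameter names in this declaration are terse. As used by the ARM kernel later in this diff, indexs holds index_size int64 indices into the first dimension, num is X's first dimension, and size is the product of X's remaining dimensions. A hypothetical call (variable names here are illustrative only) would look like:

// Hypothetical call, mirroring how ScatterCompute::Run() invokes it below.
paddle::lite::arm::math::scatter<float>(
    ids_data,      // const int64_t*: indices into the first dimension of X
    updates_data,  // const float*:   update rows, laid out as [*, size]
    out_data,      // float*:         output buffer, same shape as X
    index_size,    // number of entries in Ids
    num,           // X.dims()[0]
    size,          // product of X.dims()[1..]
    overwrite);    // true: replace rows, false: accumulate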
@@ -79,8 +79,10 @@ add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposal
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(pixel_shuffle_compute_arm ARM extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(scatter_compute_arm ARM extra SRCS scatter_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_expand_as_compute_arm ARM extra SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/scatter_compute.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ScatterCompute::Run() {
auto& param = this->template Param<operators::ScatterParam>();
const float* updates_data = param.updates->template data<float>();
const int64_t* indexs_data = param.indexs->template data<int64_t>();
float* output_data = param.output->template mutable_data<float>();
bool overwrite = param.overwrite;
int index_size = param.indexs->dims()[0];
auto in_dims = param.x->dims();
int num = 1;
for (int i = 1; i < in_dims.size(); i++) {
num *= in_dims[i];
}
lite::arm::math::scatter(indexs_data,
updates_data,
output_data,
index_size,
in_dims[0],
num,
overwrite);
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(scatter,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ScatterCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Updates",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ScatterCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ScatterCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
add_operator(scatter extra SRCS scatter_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
...
@@ -294,6 +294,16 @@ struct ScaleParam : ParamBase {
}
};
// For Scatter OP
struct ScatterParam : ParamBase {
lite::Tensor* x{};
lite::Tensor* indexs{};
lite::Tensor* updates{};
lite::Tensor* output{};
bool overwrite{true};
};
// For Softmax op
struct SoftmaxParam : ParamBase {
lite::Tensor* x{};
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/scatter_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ScatterOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool ScatterOp::InferShapeImpl() const {
auto index_dims = param_.indexs->dims();
auto update_dims = param_.updates->dims();
auto input_dims = param_.x->dims();
for (int i = 1; i < update_dims.size(); i++) {
CHECK_EQ_OR_FALSE(update_dims[i], input_dims[i]);
}
CHECK_EQ_OR_FALSE(index_dims.size(), 1L);
param_.output->Resize(input_dims);
return true;
}
bool ScatterOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front();
auto indexs = op_desc.Input("Ids").front();
auto updates = op_desc.Input("Updates").front();
auto output = op_desc.Output("Out").front();
if (op_desc.HasAttr("overwrite")) {
param_.overwrite = op_desc.GetAttr<bool>("overwrite");
} else {
param_.overwrite = true;
}
param_.x = scope->FindVar(x)->GetMutable<Tensor>();
param_.indexs = scope->FindVar(indexs)->GetMutable<Tensor>();
param_.updates = scope->FindVar(updates)->GetMutable<Tensor>();
param_.output = scope->FindMutableTensor(output);
CHECK(param_.x);
CHECK(param_.indexs);
CHECK(param_.updates);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(scatter, paddle::lite::operators::ScatterOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ScatterOp : public OpLite {
public:
ScatterOp() {}
explicit ScatterOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "Scatter"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.x->dims());
ch->output_shape = ch->DimToStr(param_.output->dims());
ch->macs = param_.x->numel() * 1.f;
}
#endif
private:
mutable ScatterParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -30,7 +30,7 @@ void bgra_to_tensor_hwc(const uint8_t* bgr,
float b_scales = scales[2];
int dim8 = width >> 3;
int remain = width - (dim8 << 3);  // fixed in this commit: previously read wwidth
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
...
@@ -66,6 +66,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_scatter_compute SRCS scatter_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
# for training kernel
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
void scatter_basic(const int64_t* indexs,
const float* src,
float* dst,
int index_size,
int num,
int size,
bool overwrite) {
for (int i = 0; i < num; i++) {
const float* din = src + indexs[i] * size;
memcpy(dst, din, sizeof(float) * size);
dst += size;
}
if (overwrite) {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
memcpy(dout, din, sizeof(float) * size);
}
} else {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
for (int j = 0; j < size; j++) {
dout[j] += din[j];
}
}
}
}
class ScatterComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string indexs_ = "indexs";
std::string updates_ = "updates";
std::string output_ = "out";
DDim up_dims_{{1}};
DDim id_dims_{{1}};
DDim x_dims_{{1}};
int index_size_ = 0;
bool overwrite_ = false;
public:
ScatterComputeTester(const Place& place,
const std::string& alias,
DDim up_dims,
DDim id_dims,
DDim x_dims,
bool overwrite,
int index_size)
: TestCase(place, alias),
up_dims_(up_dims),
id_dims_(id_dims),
x_dims_(x_dims),
index_size_(index_size),
overwrite_(overwrite) {}
void RunBaseline(Scope* scope) override {
auto* indexs_t = scope->FindMutableTensor(indexs_);
auto* updates_t = scope->FindMutableTensor(updates_);
const auto* indexs_data = indexs_t->data<int64_t>();
const auto* updates_data = updates_t->data<float>();
auto* out = scope->NewTensor(output_);
out->Resize(x_dims_);
auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
int size = in_c * in_h * in_w;
scatter_basic(indexs_data,
updates_data,
out_data,
index_size_,
in_n,
size,
overwrite_);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("scatter");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Ids", {indexs_});
op_desc->SetInput("Updates", {updates_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("overwrite", overwrite_);
}
void PrepareData() override {
std::vector<float> data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
data[i] = i * 1.0;
}
SetCommonTensor(input_, x_dims_, data.data());
std::vector<float> update(up_dims_.production());
for (int i = 0; i < up_dims_.production(); i++) {
update[i] = i * 1.0;
}
SetCommonTensor(updates_, up_dims_, update.data());
std::vector<int64_t> index(id_dims_.production());
for (int i = 0; i < id_dims_.production(); i++) {
index[i] = i;
}
SetCommonTensor(indexs_, id_dims_, index.data());
}
};
void test_scatter(Place place) {
for (auto n : {1, 3}) {
for (auto c : {1, 2}) {
for (auto h : {1, 3}) {
for (auto w : {1, 3}) {
for (bool overwrite : {false, true}) {
auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto up_dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto id_dims = DDim(std::vector<int64_t>({n}));
std::unique_ptr<arena::TestCase> tester(new ScatterComputeTester(
place, "def", up_dims, id_dims, x_dims, overwrite, n));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
TEST(Scatter, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_scatter(place);
#endif
}
} // namespace lite
} // namespace paddle