提交 f5eaee4e 编写于 作者: Z zhupengyang 提交者: hong19860320

add elementwise_add kernel and unit test (#17825)

* add elementwise_add kernel
test=develop

* add elementwise_add unit test
test=develop

* enable arm neon
test=develop

* remove Repetitive funcs
test=develop
上级 0328da9c
...@@ -6,4 +6,4 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) ...@@ -6,4 +6,4 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return() return()
endif() endif()
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc DEPS ${lite_kernel_deps} eigen3) cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc elementwise.cc DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/elementwise.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t vsum0 = vaddq_f32(dinx0, diny0);
float32x4_t vsum1 = vaddq_f32(dinx1, diny1);
float32x4_t vsum2 = vaddq_f32(dinx2, diny2);
float32x4_t vsum3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, vsum0);
vst1q_f32(dout_ptr + 4, vsum1);
vst1q_f32(dout_ptr + 8, vsum2);
vst1q_f32(dout_ptr + 12, vsum3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *dinx_ptr + *diny_ptr;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -9,16 +9,19 @@ cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) ...@@ -9,16 +9,19 @@ cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
set(arm_kernels set(arm_kernels
fc_compute_arm fc_compute_arm
relu_compute_arm relu_compute_arm
mul_compute_arm mul_compute_arm
scale_compute_arm scale_compute_arm
softmax_compute_arm) softmax_compute_arm
elementwise_add_compute_arm)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ElementwiseAddCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int n = param.X->dims().production();
lite::arm::math::elementwise_add(x_data, y_data, out_data, n);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ElementwiseAddCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ElementwiseAddCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseAddCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
TEST(elementwise_add_arm, retrive_op) {
auto elementwise_add =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_add");
ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front());
}
TEST(elementwise_add_arm, init) {
ElementwiseAddCompute elementwise_add;
ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat));
ASSERT_EQ(elementwise_add.target(), TARGET(kARM));
}
template <typename dtype>
void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
const dtype* x_data = param.X->data<const dtype>();
const dtype* y_data = param.Y->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>();
DDim dim = param.X->dims();
ASSERT_EQ(dim.data(), param.Out->dims().data());
for (int i = 0; i < dim.production(); i++) {
out_data[i] = x_data[i] + y_data[i];
}
}
TEST(elementwise_add, compute) {
ElementwiseAddCompute elementwise_add;
operators::ElementwiseParam param;
lite::Tensor x, y, out, out_ref;
x.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
y.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
out.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
out_ref.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* out_data = out.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = y_data[i] = i;
}
param.X = &x;
param.Y = &y;
param.Out = &out;
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &out_ref;
elementwise_add_compute_ref<float>(param);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册