提交 af89b659 编写于 作者: H hong19860320

add arm kernel for fusion_elementwise_add_activation op

test=develop
上级 ce6c24e6
......@@ -65,9 +65,61 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
}
template <>
void elementwise_add_axis<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
void elementwise_add_relu<float>(const float* dinx, const float* diny,
float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
dinx0 = vaddq_f32(dinx0, diny0);
dinx1 = vaddq_f32(dinx1, diny1);
dinx2 = vaddq_f32(dinx2, diny2);
dinx3 = vaddq_f32(dinx3, diny3);
// relu
dinx0 = vmaxq_f32(dinx0, vzero);
dinx1 = vmaxq_f32(dinx1, vzero);
dinx2 = vmaxq_f32(dinx2, vzero);
dinx3 = vmaxq_f32(dinx3, vzero);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
float tmp = *dinx_ptr + *diny_ptr;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_add_broadcast<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
......@@ -127,6 +179,82 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
}
}
template <>
void elementwise_add_relu_broadcast<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
din2 = vmaxq_f32(din2, vzero);
din3 = vmaxq_f32(din3, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
float tmp = *din_ptr + diny_data;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
din_ptr++;
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -23,8 +23,15 @@ template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_broadcast(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
template <typename T>
void elementwise_add_relu_broadcast(const T* dinx, const T* diny, T* dout,
int batch, int channels, int num);
} // namespace math
} // namespace arm
......
......@@ -11,7 +11,7 @@ cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_compute_arm SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -24,7 +24,7 @@ lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_comput
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
......@@ -40,7 +40,7 @@ set(arm_kernels
softmax_compute_arm
conv_compute_arm
batch_norm_compute_arm
elementwise_add_compute_arm
elementwise_compute_arm
pool_compute_arm
split_compute_arm
concat_compute_arm
......
......@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h"
#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h"
#include <string>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
......@@ -20,6 +21,30 @@ namespace lite {
namespace kernels {
namespace arm {
inline bool is_broadcast(const DDim& x_dims, const DDim& y_dims, int axis,
int* pre, int* n, int* post) {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (x_dims.size() == y_dims.size()) {
return false;
}
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
CHECK_EQ(x_dims[i + axis], y_dims[i]) << "Broadcast dimension mismatch.";
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
return true;
}
void ElementwiseAddCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
......@@ -28,27 +53,40 @@ void ElementwiseAddCompute::Run() {
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (x_dims.size() == y_dims.size()) {
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_broadcast(x_data, y_data, out_data, pre, n,
post);
} else {
lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production());
} else {
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
}
void ElementwiseAddActivationCompute::Run() {
auto& param = Param<operators::FusionElementwiseActivationParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis;
std::string act_type = param.act_type;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_add_relu_broadcast(x_data, y_data, out_data,
pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
} else {
if (act_type == "relu") {
lite::arm::math::elementwise_add_relu(x_data, y_data, out_data,
x_dims.production());
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch,
channels, num);
}
}
......@@ -63,3 +101,11 @@ REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_add_activation, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ElementwiseAddActivationCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -30,6 +30,14 @@ class ElementwiseAddCompute
virtual ~ElementwiseAddCompute() = default;
};
class ElementwiseAddActivationCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseAddActivationCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h"
#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -37,7 +38,9 @@ TEST(elementwise_add_arm, init) {
}
template <typename dtype>
void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
void elementwise_compute_ref(const operators::ElementwiseParam& param,
const std::string elt_type,
const std::string act_type) {
const dtype* x_data = param.X->data<const dtype>();
const dtype* y_data = param.Y->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>();
......@@ -59,17 +62,52 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
// do elementwise add/sub/max...
if (elt_type == "add") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "sub") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr - diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else {
LOG(FATAL) << "unsupported Elementwise type: " << elt_type;
}
// do activation relu/sigmod...
if (act_type.size() > 0) {
if (act_type == "relu") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
dtype* dout_ptr = out_data + (i * channels + j) * num;
for (int k = 0; k < num; ++k) {
*dout_ptr = *dout_ptr > 0.0f ? *dout_ptr : 0.0f;
dout_ptr++;
}
}
}
} else {
LOG(FATAL) << "unsupported Activation type: " << elt_type;
}
}
}
......@@ -123,7 +161,7 @@ TEST(elementwise_add, compute) {
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &output_ref;
elementwise_add_compute_ref<float>(param);
elementwise_compute_ref<float>(param, "add", "");
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
......@@ -135,9 +173,91 @@ TEST(elementwise_add, compute) {
}
}
TEST(fusion_elementwise_add_activation_arm, retrive_op) {
auto fusion_elementwise_add_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"fusion_elementwise_add_activation");
ASSERT_FALSE(fusion_elementwise_add_activation.empty());
ASSERT_TRUE(fusion_elementwise_add_activation.front());
}
TEST(fusion_elementwise_add_activation_arm, init) {
ElementwiseAddActivationCompute fusion_elementwise_add_activation;
ASSERT_EQ(fusion_elementwise_add_activation.precision(), PRECISION(kFloat));
ASSERT_EQ(fusion_elementwise_add_activation.target(), TARGET(kARM));
}
TEST(fusion_elementwise_add_activation_arm, compute) {
ElementwiseAddActivationCompute fusion_elementwise_add_activation;
operators::FusionElementwiseActivationParam param;
lite::Tensor x, y, output, output_ref;
for (auto act_type : {"relu"}) {
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 4, 11}) {
for (auto h : {1, 3, 4, 11}) {
for (auto w : {1, 3, 4, 11}) {
for (auto axis : {-1, 0, 1, 2, 3}) {
for (auto yd :
{std::vector<int64_t>({n}), std::vector<int64_t>({c}),
std::vector<int64_t>({h}), std::vector<int64_t>({w}),
std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
std::vector<int64_t>({h, w}),
std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
if (axis_t + y_dim.size() > 4) continue;
bool flag = false;
for (int i = 0; i < y_dim.size(); i++) {
if (x_dim[i + axis_t] != y_dim[i]) flag = true;
}
if (flag) continue;
x.Resize(x_dim);
y.Resize(y_dim);
output.Resize(x_dim);
output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x_dim.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = i * sign;
}
for (int i = 0; i < y_dim.production(); i++) {
float sign = i % 2 == 0 ? 0.5f : -0.5f;
y_data[i] = i * sign;
}
param.X = &x;
param.Y = &y;
param.axis = axis;
param.Out = &output;
param.act_type = act_type;
fusion_elementwise_add_activation.SetParam(param);
fusion_elementwise_add_activation.Run();
param.Out = &output_ref;
elementwise_compute_ref<float>(param, "add", act_type);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fusion_elementwise_add_activation, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册