Commit 32b5caa0 authored by: Z zhangwen31

[arm][math]feat: elementwise_add support i32 and i64 now

Parent 8018225c
@@ -21,6 +21,41 @@ namespace lite {
namespace arm {
namespace math {
// TODO: remove this helper once all elementwise_add variants have optimized implementations
template <typename T>
static void naive_elementwise_op(
const T* dinx, const T* diny, T* dout, int num, std::function<T(T, T)> op) {
for (int i = 0; i < num; ++i) {
*dout = op(*dinx, *diny);
++dinx;
++diny;
++dout;
}
}
// TODO: remove this helper once all elementwise_add variants have optimized implementations
template <typename T>
static T naive_add(T l, T r) {
return l + r;
}
// TODO: use ARM NEON intrinsics
template <>
void elementwise_add<int32_t>(const int32_t* dinx,
const int32_t* diny,
int32_t* dout,
int num) {
naive_elementwise_op<int32_t>(dinx, diny, dout, num, naive_add<int32_t>);
}
// TODO: use ARM NEON intrinsics
template <>
void elementwise_add<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int num) {
naive_elementwise_op<int64_t>(dinx, diny, dout, num, naive_add<int64_t>);
}
template <>
void elementwise_add<float>(const float* dinx,
                            const float* diny,
@@ -178,6 +213,54 @@ void elementwise_add_tanh<float>(const float* dinx,
  }
}
// TODO: remove this helper once all elementwise_add variants have optimized implementations
template <typename T>
static void naive_elementwise_op_broadcast(const T* x_data,
const T* y_data,
T* out_data,
int batch,
int channels,
int num,
std::function<T(T, T)> op) {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const T* din_ptr = x_data + offset;
const T diny_data = y_data[j];
T* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = op(*din_ptr, diny_data);
dout_ptr++;
din_ptr++;
}
}
}
}
// TODO: use ARM NEON intrinsics
template <>
void elementwise_add_broadcast<int32_t>(const int32_t* dinx,
const int32_t* diny,
int32_t* dout,
int batch,
int channels,
int num) {
naive_elementwise_op_broadcast<int32_t>(
dinx, diny, dout, batch, channels, num, naive_add<int32_t>);
}
// TODO: use ARM NEON intrinsics
template <>
void elementwise_add_broadcast<int64_t>(const int64_t* dinx,
const int64_t* diny,
int64_t* dout,
int batch,
int channels,
int num) {
naive_elementwise_op_broadcast<int64_t>(
dinx, diny, dout, batch, channels, num, naive_add<int64_t>);
}
template <>
void elementwise_add_broadcast<float>(const float* dinx,
                                      const float* diny,
...
@@ -387,6 +387,26 @@ REGISTER_LITE_KERNEL(
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
using elementwise_add_int32_t =
paddle::lite::kernels::arm::ElementwiseAddCompute<int32_t,
PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
elementwise_add, kARM, kInt32, kNCHW, elementwise_add_int32_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
using elementwise_add_int64_t =
paddle::lite::kernels::arm::ElementwiseAddCompute<int64_t,
PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
elementwise_add, kARM, kInt64, kNCHW, elementwise_add_int64_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
REGISTER_LITE_KERNEL(
    fusion_elementwise_add_activation,
    kARM,
...
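For context, the snippet below is a minimal standalone sketch of the naive fallback pattern this commit introduces. It reuses the naive_elementwise_op / naive_add names from the diff but compiles on its own, without any Paddle Lite headers, and the sample values are purely illustrative.

// Standalone sketch mirroring the naive helper added in this commit.
// Illustrative only; does not depend on Paddle Lite headers.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

template <typename T>
static void naive_elementwise_op(
    const T* dinx, const T* diny, T* dout, int num, std::function<T(T, T)> op) {
  // Apply `op` element by element over two equally sized buffers.
  for (int i = 0; i < num; ++i) {
    *dout = op(*dinx, *diny);
    ++dinx;
    ++diny;
    ++dout;
  }
}

template <typename T>
static T naive_add(T l, T r) {
  return l + r;
}

int main() {
  std::vector<int32_t> x = {1, 2, 3, 4};
  std::vector<int32_t> y = {10, 20, 30, 40};
  std::vector<int32_t> out(x.size());
  naive_elementwise_op<int32_t>(x.data(),
                                y.data(),
                                out.data(),
                                static_cast<int>(x.size()),
                                naive_add<int32_t>);
  for (int32_t v : out) std::cout << v << " ";  // prints: 11 22 33 44
  std::cout << "\n";
  return 0;
}

The float specialization keeps its existing optimized path; the new int32/int64 specializations fall back to this scalar loop until the NEON TODOs above are addressed.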