Commit dfcd059b authored by zhangwen31

[arm][kernel][math] feat: add i32 and i64 support for elementwise sub

Parent db658bec
@@ -38,6 +38,12 @@ static T naive_add(T l, T r) {
  return l + r;
}
// todo: remove this function once all elementwise sub variants have real implementations
template <typename T>
static T naive_sub(T l, T r) {
  return l - r;
}
// todo: use arm intrinsics
template <>
void elementwise_add<int32_t>(const int32_t* dinx,
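The new naive_sub helper mirrors naive_add and is plugged into a naive_elementwise_op helper at the call sites in the next hunk. That helper's definition is outside this diff; a minimal sketch of the shape it presumably has, with the signature inferred from the call sites (an assumption, not code from this commit):

// Assumed helper (not part of this diff): a plain scalar loop that applies
// `op` to each element pair. Signature inferred from the call sites below.
template <typename T>
static void naive_elementwise_op(
    const T* dinx, const T* diny, T* dout, int num, T (*op)(T, T)) {
  for (int i = 0; i < num; ++i) {
    dout[i] = op(dinx[i], diny[i]);
  }
}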
@@ -472,6 +478,25 @@ void elementwise_add_grad_broadcast<float>(const float* dout_grad,
    }
  }
}
// todo: use arm intrinsics
template <>
void elementwise_sub<int32_t>(const int32_t* dinx,
                              const int32_t* diny,
                              int32_t* dout,
                              int num) {
  naive_elementwise_op<int32_t>(dinx, diny, dout, num, naive_sub<int32_t>);
}
// todo: use arm intrinsics
template <>
void elementwise_sub<int64_t>(const int64_t* dinx,
                              const int64_t* diny,
                              int64_t* dout,
                              int num) {
  naive_elementwise_op<int64_t>(dinx, diny, dout, num, naive_sub<int64_t>);
}
template <>
void elementwise_sub<float>(const float* dinx,
                            const float* diny,
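Each new specialization above carries a "todo: use arm intrinsics" marker, i.e. the int32_t and int64_t paths are scalar for now. For illustration only, a hedged sketch of how the int32_t loop could later be vectorized with NEON (elementwise_sub_i32_neon is a hypothetical name, not part of this commit):

#include <arm_neon.h>
#include <cstdint>

// Hypothetical NEON variant of the int32_t path: subtract four lanes per
// iteration, then fall back to a scalar loop for the tail elements.
static void elementwise_sub_i32_neon(const int32_t* dinx,
                                     const int32_t* diny,
                                     int32_t* dout,
                                     int num) {
  int i = 0;
  for (; i + 4 <= num; i += 4) {
    int32x4_t vx = vld1q_s32(dinx + i);      // load 4 x values
    int32x4_t vy = vld1q_s32(diny + i);      // load 4 y values
    vst1q_s32(dout + i, vsubq_s32(vx, vy));  // store x - y
  }
  for (; i < num; ++i) {
    dout[i] = dinx[i] - diny[i];  // scalar tail
  }
}

An int64_t version would be analogous with vsubq_s64 at two lanes per iteration.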
@@ -572,6 +597,30 @@ void elementwise_sub_relu<float>(const float* dinx,
  }
}
// todo: use arm intrinsics
template <>
void elementwise_sub_broadcast<int32_t>(const int32_t* dinx,
                                        const int32_t* diny,
                                        int32_t* dout,
                                        int batch,
                                        int channels,
                                        int num) {
  naive_elementwise_op_broadcast<int32_t>(
      dinx, diny, dout, batch, channels, num, naive_sub<int32_t>);
}
// todo: use arm intrinsics
template <>
void elementwise_sub_broadcast<int64_t>(const int64_t* dinx,
                                        const int64_t* diny,
                                        int64_t* dout,
                                        int batch,
                                        int channels,
                                        int num) {
  naive_elementwise_op_broadcast<int64_t>(
      dinx, diny, dout, batch, channels, num, naive_sub<int64_t>);
}
template <>
void elementwise_sub_broadcast<float>(const float* dinx,
                                      const float* diny,
...
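The broadcast specializations above delegate to naive_elementwise_op_broadcast, also defined outside this diff. A sketch under the assumption that y holds one value per channel, broadcast across the num contiguous elements of each (batch, channel) slice:

// Assumed helper (not part of this diff): x and out are laid out as
// batch x channels x num; y has one element per channel, reused across
// the whole channel slice.
template <typename T>
static void naive_elementwise_op_broadcast(const T* dinx,
                                           const T* diny,
                                           T* dout,
                                           int batch,
                                           int channels,
                                           int num,
                                           T (*op)(T, T)) {
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      const T* din_ptr = dinx + (b * channels + c) * num;
      T* dout_ptr = dout + (b * channels + c) * num;
      const T y = diny[c];
      for (int i = 0; i < num; ++i) {
        dout_ptr[i] = op(din_ptr[i], y);
      }
    }
  }
}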
@@ -429,6 +429,26 @@ REGISTER_LITE_KERNEL(
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
using elementwise_sub_int32_t =
    paddle::lite::kernels::arm::ElementwiseSubCompute<int32_t,
                                                      PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
    elementwise_sub, kARM, kInt32, kNCHW, elementwise_sub_int32_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .Finalize();
using elementwise_sub_int64_t =
    paddle::lite::kernels::arm::ElementwiseSubCompute<int64_t,
                                                      PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
    elementwise_sub, kARM, kInt64, kNCHW, elementwise_sub_int64_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();
REGISTER_LITE_KERNEL(
    fusion_elementwise_sub_activation,
    kARM,
......
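With the two REGISTER_LITE_KERNEL blocks above, an elementwise_sub op whose inputs are kInt32 or kInt64 ARM tensors can now be matched to the new kernels. A hypothetical smoke test for the underlying int32_t math routine (the header path and namespace are assumptions based on Paddle-Lite's usual layout, not part of this commit):

#include <cstdint>
#include <vector>

#include "lite/backends/arm/math/elementwise.h"  // assumed header path

void smoke_test_elementwise_sub_i32() {
  std::vector<int32_t> x = {10, 20, 30, 40};
  std::vector<int32_t> y = {1, 2, 3, 4};
  std::vector<int32_t> out(x.size());
  paddle::lite::arm::math::elementwise_sub<int32_t>(
      x.data(), y.data(), out.data(), static_cast<int>(x.size()));
  // Expected: out == {9, 18, 27, 36}.
}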