提交 db658bec 编写于 作者: Z zhangwen31

[arm][kernel] refactor: elementwise-sub uses template now

上级 b059fb33
......@@ -132,24 +132,25 @@ void ElementwiseAddActivationCompute::Run() {
}
}
void ElementwiseSubCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
template <typename T, PrecisionType PType>
void ElementwiseSubCompute<T, PType>::Run() {
auto& param = this->template Param<operators::ElementwiseParam>();
const T* x_data = param.X->template data<T>();
const T* y_data = param.Y->template data<T>();
T* out_data = param.Out->template mutable_data<T>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast(
lite::arm::math::elementwise_sub_broadcast<T>(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast(
lite::arm::math::elementwise_sub_broadcast<T>(
x_data, y_data, out_data, pre, n, post);
} else {
lite::arm::math::elementwise_sub(
lite::arm::math::elementwise_sub<T>(
x_data, y_data, out_data, x_dims.production());
}
}
......@@ -419,12 +420,10 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseSubCompute,
def)
using elementwise_sub_float_t =
paddle::lite::kernels::arm::ElementwiseSubCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
elementwise_sub, kARM, kFloat, kNCHW, elementwise_sub_float_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -38,8 +38,8 @@ class ElementwiseAddActivationCompute
virtual ~ElementwiseAddActivationCompute() = default;
};
class ElementwiseSubCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class ElementwiseSubCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
......
......@@ -27,7 +27,7 @@ using param_t = operators::ElementwiseParam;
using grad_param_t = operators::ElementwiseGradParam;
using kernel_add_t = ElementwiseAddCompute<float, PRECISION(kFloat)>;
using grad_kernel_add_t = ElementwiseAddGradCompute;
using kernel_sub_t = ElementwiseSubCompute;
using kernel_sub_t = ElementwiseSubCompute<float, PRECISION(kFloat)>;
using grad_kernel_sub_t = ElementwiseSubGradCompute;
void elementwise_common(grad_param_t& param, // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册