Commit db658bec authored by zhangwen31

[arm][kernel] refactor: elementwise-sub uses template now

Parent b059fb33
@@ -132,24 +132,25 @@ void ElementwiseAddActivationCompute::Run() {
   }
 }
 
-void ElementwiseSubCompute::Run() {
-  auto& param = Param<operators::ElementwiseParam>();
-  const float* x_data = param.X->data<float>();
-  const float* y_data = param.Y->data<float>();
-  float* out_data = param.Out->mutable_data<float>();
+template <typename T, PrecisionType PType>
+void ElementwiseSubCompute<T, PType>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  const T* x_data = param.X->template data<T>();
+  const T* y_data = param.Y->template data<T>();
+  T* out_data = param.Out->template mutable_data<T>();
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
   if (x_dims.size() < y_dims.size() &&
       is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_sub_broadcast(
+    lite::arm::math::elementwise_sub_broadcast<T>(
         y_data, x_data, out_data, pre, n, post);
   } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_sub_broadcast(
+    lite::arm::math::elementwise_sub_broadcast<T>(
         x_data, y_data, out_data, pre, n, post);
   } else {
-    lite::arm::math::elementwise_sub(
+    lite::arm::math::elementwise_sub<T>(
         x_data, y_data, out_data, x_dims.production());
   }
 }
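Note on the new `this->template Param<...>()` spelling: once Run() becomes a member of a class template, names inherited from the dependent base KernelLite<TARGET(kARM), PType> are no longer found by unqualified lookup, and calling a member template with explicit template arguments additionally requires the `template` keyword so that the `<` is not parsed as less-than. A minimal standalone illustration of both rules (hypothetical types, not PaddleLite code):

// Minimal sketch: why the dependent-base call needs this->template.
template <typename T>
struct Base {
  template <typename U>
  U Param() { return U{}; }  // stands in for KernelLite's Param<>()
};

template <typename T>
struct Derived : Base<T> {
  int Run() {
    // return Param<int>();        // error: member of dependent base not visible
    // return this->Param<int>();  // error: '<' would parse as less-than
    return this->template Param<int>();  // OK: qualified and disambiguated
  }
};

int main() {
  Derived<float> d;
  return d.Run();
}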
@@ -419,12 +420,10 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(elementwise_sub,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ElementwiseSubCompute,
-                     def)
+using elementwise_sub_float_t =
+    paddle::lite::kernels::arm::ElementwiseSubCompute<float, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(
+    elementwise_sub, kARM, kFloat, kNCHW, elementwise_sub_float_t, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
...
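Templating the kernel makes further precisions a matter of one alias plus one registration. This commit registers only the float kernel; as a hypothetical sketch (not part of this commit, and assuming lite::arm::math provides int32_t instantiations of elementwise_sub and elementwise_sub_broadcast), an int32 variant would follow the same pattern:

// Hypothetical sketch, not in this commit: an int32 variant of the same kernel.
using elementwise_sub_int32_t =
    paddle::lite::kernels::arm::ElementwiseSubCompute<int32_t, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
    elementwise_sub, kARM, kInt32, kNCHW, elementwise_sub_int32_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .Finalize();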
@@ -38,8 +38,8 @@ class ElementwiseAddActivationCompute
   virtual ~ElementwiseAddActivationCompute() = default;
 };
 
-class ElementwiseSubCompute
-    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+template <typename T, PrecisionType PType>
+class ElementwiseSubCompute : public KernelLite<TARGET(kARM), PType> {
  public:
   void Run() override;
...
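Note: Run() remains defined in the .cc file even though the class is now a template. This links only because the same translation unit instantiates ElementwiseSubCompute<float, PRECISION(kFloat)> through the registration above: constructing the kernel object forces its vtable, and with it the virtual Run() definition, to be emitted in that unit. A minimal standalone illustration (hypothetical types, not PaddleLite code):

#include <iostream>

struct KernelBase {
  virtual ~KernelBase() = default;
  virtual void Run() = 0;
};

template <typename T>
struct Kernel : KernelBase {
  void Run() override;  // declared in the "header"
};

template <typename T>
void Kernel<T>::Run() {  // defined in the "source file"
  std::cout << "Run for a " << sizeof(T) << "-byte element type\n";
}

int main() {
  // Constructing Kernel<float> here implicitly instantiates the class and,
  // because Run() is virtual and needed for the vtable, its body as well.
  Kernel<float> k;
  k.Run();
  return 0;
}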
@@ -27,7 +27,7 @@ using param_t = operators::ElementwiseParam;
 using grad_param_t = operators::ElementwiseGradParam;
 using kernel_add_t = ElementwiseAddCompute<float, PRECISION(kFloat)>;
 using grad_kernel_add_t = ElementwiseAddGradCompute;
-using kernel_sub_t = ElementwiseSubCompute;
+using kernel_sub_t = ElementwiseSubCompute<float, PRECISION(kFloat)>;
 using grad_kernel_sub_t = ElementwiseSubGradCompute;
 
 void elementwise_common(grad_param_t& param,  // NOLINT
...