Commit 8018225c authored by zhangwen31

[arm][kernel] refactor: ElementwiseAddCompute in arm kernel is now a template

Parent 069e4c52
@@ -71,24 +71,25 @@ inline bool is_broadcast(const DDim& x_dims,
   return true;
 }
 
-void ElementwiseAddCompute::Run() {
-  auto& param = Param<operators::ElementwiseParam>();
-  const float* x_data = param.X->data<float>();
-  const float* y_data = param.Y->data<float>();
-  float* out_data = param.Out->mutable_data<float>();
+template <typename T, PrecisionType PType>
+void ElementwiseAddCompute<T, PType>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  const T* x_data = param.X->template data<T>();
+  const T* y_data = param.Y->template data<T>();
+  T* out_data = param.Out->template mutable_data<T>();
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
   if (x_dims.size() < y_dims.size() &&
       is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_add_broadcast(
+    lite::arm::math::elementwise_add_broadcast<T>(
         y_data, x_data, out_data, pre, n, post);
   } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_add_broadcast(
+    lite::arm::math::elementwise_add_broadcast<T>(
         x_data, y_data, out_data, pre, n, post);
   } else {
-    lite::arm::math::elementwise_add(
+    lite::arm::math::elementwise_add<T>(
         x_data, y_data, out_data, x_dims.production());
   }
 }
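The most subtle change in this hunk is the switch to `this->template Param<...>()` and the `->template data<T>()` calls. Once the class is a template, `Param` lives in a base class that depends on the template parameters, so unqualified lookup no longer finds it, and the `template` keyword is needed for the parser to read a dependent member-template call. A minimal, self-contained sketch with hypothetical names (not Paddle-Lite code):

    #include <iostream>

    template <typename T>
    struct Base {
      template <typename P>
      P Param() { return P{}; }
    };

    template <typename T>
    struct Derived : Base<T> {
      void Run() {
        // A bare Param<int>() would not compile here: Base<T> is a dependent
        // base, so its members are invisible to unqualified lookup, and the
        // "template" keyword tells the parser that Param is a member template
        // rather than a less-than comparison.
        int p = this->template Param<int>();
        std::cout << p << std::endl;
      }
    };

    int main() {
      Derived<float>().Run();
      return 0;
    }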
@@ -377,12 +378,10 @@ void ElementwiseModCompute<T, PType>::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(elementwise_add,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ElementwiseAddCompute,
-                     def)
+using elementwise_add_float_t =
+    paddle::lite::kernels::arm::ElementwiseAddCompute<float, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(
+    elementwise_add, kARM, kFloat, kNCHW, elementwise_add_float_t, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
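The `using elementwise_add_float_t = ...` alias introduced here is not just cosmetic: `ElementwiseAddCompute<float, PRECISION(kFloat)>` contains a top-level comma, and the preprocessor would split it into two macro arguments if it were passed to `REGISTER_LITE_KERNEL` directly. A minimal sketch of the problem with a stand-in macro (`REGISTER_KERNEL` below is hypothetical, not the real Paddle-Lite macro):

    #include <iostream>

    template <typename T, int Precision>
    struct Kernel {
      void Run() { std::cout << "precision " << Precision << std::endl; }
    };

    // Stand-in one-argument registration macro.
    #define REGISTER_KERNEL(kernel_t) \
      void RegisterKernel() { kernel_t().Run(); }

    // REGISTER_KERNEL(Kernel<float, 1>)  // error: macro expects 1 argument, got 2
    using kernel_float_t = Kernel<float, 1>;  // the alias hides the comma
    REGISTER_KERNEL(kernel_float_t)

    int main() {
      RegisterKernel();
      return 0;
    }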
@@ -22,8 +22,8 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
-class ElementwiseAddCompute
-    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+template <typename T, PrecisionType PType>
+class ElementwiseAddCompute : public KernelLite<TARGET(kARM), PType> {
  public:
   void Run() override;
......
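With the class now templated on both the element type and the kernel precision, supporting another precision no longer requires a new class: it is one instantiation plus one registration. A hypothetical sketch following the registration pattern earlier in this diff (an int32 variant is not part of this commit, and the exact `GetTensorTy` arguments and `.Finalize()` closer are assumptions about the registration API):

    using elementwise_add_int32_t =
        paddle::lite::kernels::arm::ElementwiseAddCompute<int32_t, PRECISION(kInt32)>;
    REGISTER_LITE_KERNEL(
        elementwise_add, kARM, kInt32, kNCHW, elementwise_add_int32_t, def)
        .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
        .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
        .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
        .Finalize();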
@@ -33,7 +33,7 @@ TEST(elementwise_add_arm, retrive_op) {
 }
 
 TEST(elementwise_add_arm, init) {
-  ElementwiseAddCompute elementwise_add;
+  ElementwiseAddCompute<float, PRECISION(kFloat)> elementwise_add;
   ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat));
   ASSERT_EQ(elementwise_add.target(), TARGET(kARM));
 }
@@ -255,7 +255,7 @@ template void elementwise_imod_compute_ref<int64_t>(
     const operators::ElementwiseParam& param, const std::string act_type);
 
 TEST(elementwise_add, compute) {
-  ElementwiseAddCompute elementwise_add;
+  ElementwiseAddCompute<float, PRECISION(kFloat)> elementwise_add;
   operators::ElementwiseParam param;
   lite::Tensor x, y, output, output_ref;
......
@@ -25,7 +25,7 @@ namespace arm {
 using param_t = operators::ElementwiseParam;
 using grad_param_t = operators::ElementwiseGradParam;
-using kernel_add_t = ElementwiseAddCompute;
+using kernel_add_t = ElementwiseAddCompute<float, PRECISION(kFloat)>;
 using grad_kernel_add_t = ElementwiseAddGradCompute;
 using kernel_sub_t = ElementwiseSubCompute;
 using grad_kernel_sub_t = ElementwiseSubGradCompute;
......