Commit 8018225c authored by zhangwen31

[arm][kernel] refactor: make ElementwiseAddCompute in the ARM kernel a class template

Parent 069e4c52
......@@ -71,24 +71,25 @@ inline bool is_broadcast(const DDim& x_dims,
return true;
}
void ElementwiseAddCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
template <typename T, PrecisionType PType>
void ElementwiseAddCompute<T, PType>::Run() {
  // Elementwise addition Out = X + Y, dispatching to a broadcast kernel when
  // the operand shapes differ in rank and to the plain kernel otherwise.
  auto& param = this->template Param<operators::ElementwiseParam>();
  const T* lhs = param.X->template data<T>();
  const T* rhs = param.Y->template data<T>();
  T* dst = param.Out->template mutable_data<T>();
  const int axis = param.axis;
  const auto dims_x = param.X->dims();
  const auto dims_y = param.Y->dims();
  int pre = 0;
  int n = 0;
  int post = 0;
  // Check the lower-rank operand first so is_broadcast() is only consulted
  // with the higher-rank tensor as the "outer" shape (short-circuit preserved).
  const bool x_has_lower_rank = dims_x.size() < dims_y.size();
  if (x_has_lower_rank && is_broadcast(dims_y, dims_x, axis, &pre, &n, &post)) {
    // X broadcasts over Y: addition is commutative, so swap the operands and
    // let the broadcast kernel treat Y as the full-size input.
    lite::arm::math::elementwise_add_broadcast<T>(
        rhs, lhs, dst, pre, n, post);
  } else if (is_broadcast(dims_x, dims_y, axis, &pre, &n, &post)) {
    // Y broadcasts over X.
    lite::arm::math::elementwise_add_broadcast<T>(
        lhs, rhs, dst, pre, n, post);
  } else {
    // Identical shapes: one flat pass over every element.
    lite::arm::math::elementwise_add<T>(
        lhs, rhs, dst, dims_x.production());
  }
}
......@@ -377,12 +378,10 @@ void ElementwiseModCompute<T, PType>::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseAddCompute,
def)
using elementwise_add_float_t =
paddle::lite::kernels::arm::ElementwiseAddCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
elementwise_add, kARM, kFloat, kNCHW, elementwise_add_float_t, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -22,8 +22,8 @@ namespace lite {
namespace kernels {
namespace arm {
class ElementwiseAddCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class ElementwiseAddCompute : public KernelLite<TARGET(kARM), PType> {
public:
void Run() override;
......
......@@ -33,7 +33,7 @@ TEST(elementwise_add_arm, retrive_op) {
}
TEST(elementwise_add_arm, init) {
ElementwiseAddCompute elementwise_add;
ElementwiseAddCompute<float, PRECISION(kFloat)> elementwise_add;
ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat));
ASSERT_EQ(elementwise_add.target(), TARGET(kARM));
}
......@@ -255,7 +255,7 @@ template void elementwise_imod_compute_ref<int64_t>(
const operators::ElementwiseParam& param, const std::string act_type);
TEST(elementwise_add, compute) {
ElementwiseAddCompute elementwise_add;
ElementwiseAddCompute<float, PRECISION(kFloat)> elementwise_add;
operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
......
......@@ -25,7 +25,7 @@ namespace arm {
using param_t = operators::ElementwiseParam;
using grad_param_t = operators::ElementwiseGradParam;
using kernel_add_t = ElementwiseAddCompute;
using kernel_add_t = ElementwiseAddCompute<float, PRECISION(kFloat)>;
using grad_kernel_add_t = ElementwiseAddGradCompute;
using kernel_sub_t = ElementwiseSubCompute;
using grad_kernel_sub_t = ElementwiseSubGradCompute;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment