diff --git a/lite/kernels/x86/elementwise_compute.cc b/lite/kernels/x86/elementwise_compute.cc index 710e67956b055b84323a23443c671682704dd2c2..67b686aa32a9e9245ebfaf0971e3e3faa5945b52 100644 --- a/lite/kernels/x86/elementwise_compute.cc +++ b/lite/kernels/x86/elementwise_compute.cc @@ -35,3 +35,14 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +REGISTER_LITE_KERNEL(elementwise_mul, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::ElementwiseMulCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/elementwise_compute.h b/lite/kernels/x86/elementwise_compute.h index c5598545f112e1d44739c6c88980f74875127836..a5afa255642f0c59ee774a0bd196c5181185f28e 100644 --- a/lite/kernels/x86/elementwise_compute.h +++ b/lite/kernels/x86/elementwise_compute.h @@ -33,6 +33,11 @@ struct AddFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return a + b; } }; +template +struct MulFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a * b; } +}; + template class ElementwiseSubCompute : public KernelLite { @@ -71,6 +76,24 @@ class ElementwiseAddCompute virtual ~ElementwiseAddCompute() = default; }; +template +class ElementwiseMulCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + param.Out->template mutable_data(); + paddle::lite::kernels::x86::ElementwiseComputeEx, + lite::TargetType::kX86, + T>( + context, param.X, param.Y, param.axis, MulFunctor(), param.Out); + } + + virtual ~ElementwiseMulCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite diff --git a/lite/kernels/x86/elementwise_op_function.h b/lite/kernels/x86/elementwise_op_function.h index c49f21d1a8ee20db249274874e21accd00dfbcd1..a94944a7f2873a3e4651f21d7db797231f229aaa 100644 --- a/lite/kernels/x86/elementwise_op_function.h +++ b/lite/kernels/x86/elementwise_op_function.h @@ -324,7 +324,7 @@ void ElementwiseComputeEx(const lite::Context &ctx, } axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + PADDLE_ENFORCE(axis >= 0 && axis < static_cast(x_dims.size()), "Axis should be in range [0, x_dims)"); auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); axis = (y_dims.size() == 0) ? x_dims.size() : axis;