From 62ea82d041b9caccc5817d81792edde35c528647 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Tue, 10 Sep 2019 16:51:09 +0800
Subject: [PATCH] add elementwise_sub and modify argmax (#1964)

---
 lite/api/_paddle_use_kernels.h                |   1 +
 lite/api/_paddle_use_ops.h                    |   2 +-
 lite/backends/arm/math/elementwise.cc         | 245 ++++++++++++++++++
 lite/backends/arm/math/elementwise.h          |  14 +
 lite/kernels/arm/argmax_compute.cc            |   8 +-
 lite/kernels/arm/argmax_compute_test.cc       |   4 +-
 lite/kernels/arm/elementwise_compute.cc       |  72 ++++-
 lite/kernels/arm/elementwise_compute.h        |  16 ++
 lite/operators/argmax_op.cc                   |   4 +-
 lite/operators/op_params.h                    |   1 -
 lite/tests/kernels/argmax_compute_test.cc     |   6 +-
 .../tests/kernels/elementwise_compute_test.cc | 122 ++++++++-
 12 files changed, 479 insertions(+), 16 deletions(-)

diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h
index c7c64bd933..50456e02cd 100644
--- a/lite/api/_paddle_use_kernels.h
+++ b/lite/api/_paddle_use_kernels.h
@@ -45,6 +45,7 @@ USE_LITE_KERNEL(box_coder, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_sub, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(elementwise_div, kARM, kFloat, kNCHW, def);
diff --git a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h
index df1aa9d35d..890c57c4aa 100644
--- a/lite/api/_paddle_use_ops.h
+++ b/lite/api/_paddle_use_ops.h
@@ -51,7 +51,7 @@ USE_LITE_OP(batch_norm)
 USE_LITE_OP(fusion_elementwise_sub_activation)
 USE_LITE_OP(transpose)
 USE_LITE_OP(transpose2)
-USE_LITE_OP(argmax)
+USE_LITE_OP(arg_max)
 USE_LITE_OP(axpy)
 USE_LITE_OP(leaky_relu)
 USE_LITE_OP(relu_clipped)
diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index 28ae1ee4ca..a4c61f9a9d 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -266,6 +266,251 @@ void elementwise_add_relu_broadcast<float>(const float* dinx,
   }
 }
 
+template <>
+void elementwise_sub<float>(const float* dinx,
+                            const float* diny,
+                            float* dout,
+                            int num) {
+  int cnt = num >> 4;
+  int remain = num % 16;
+#pragma omp parallel for
+  for (int i = 0; i < cnt; i++) {
+    const float* dinx_ptr = dinx + (i << 4);
+    const float* diny_ptr = diny + (i << 4);
+    float* dout_ptr = dout + (i << 4);
+
+    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
+    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
+    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
+    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
+
+    float32x4_t diny0 = vld1q_f32(diny_ptr);
+    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
+    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
+    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
+
+    dinx0 = vsubq_f32(dinx0, diny0);
+    dinx1 = vsubq_f32(dinx1, diny1);
+    dinx2 = vsubq_f32(dinx2, diny2);
+    dinx3 = vsubq_f32(dinx3, diny3);
+
+    vst1q_f32(dout_ptr, dinx0);
+    vst1q_f32(dout_ptr + 4, dinx1);
+    vst1q_f32(dout_ptr + 8, dinx2);
+    vst1q_f32(dout_ptr + 12, dinx3);
+  }
+  if (remain > 0) {
+    const float* dinx_ptr = dinx + (cnt << 4);
+    const float* diny_ptr = diny + (cnt << 4);
+    float* dout_ptr = dout + (cnt << 4);
+    for (int i = 0; i < remain; i++) {
+      *dout_ptr = *dinx_ptr - *diny_ptr;
+      dout_ptr++;
+      dinx_ptr++;
+      diny_ptr++;
+    }
+  }
+}
+
+template <>
+void elementwise_sub_relu<float>(const float* dinx,
+                                 const float* diny,
+                                 float* dout,
+                                 int num) {
+  int cnt = num >> 4;
+  int remain = num % 16;
+  float32x4_t vzero = vdupq_n_f32(0.f);
+#pragma omp parallel for
+  for (int i = 0; i < cnt; i++) {
+    const float* dinx_ptr = dinx + (i << 4);
+    const float* diny_ptr = diny + (i << 4);
+    float* dout_ptr = dout + (i << 4);
+
+    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
+    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
+    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
+    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
+
+    float32x4_t diny0 = vld1q_f32(diny_ptr);
+    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
+    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
+    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
+
+    dinx0 = vsubq_f32(dinx0, diny0);
+    dinx1 = vsubq_f32(dinx1, diny1);
+    dinx2 = vsubq_f32(dinx2, diny2);
+    dinx3 = vsubq_f32(dinx3, diny3);
+
+    // relu
+    dinx0 = vmaxq_f32(dinx0, vzero);
+    dinx1 = vmaxq_f32(dinx1, vzero);
+    dinx2 = vmaxq_f32(dinx2, vzero);
+    dinx3 = vmaxq_f32(dinx3, vzero);
+
+    vst1q_f32(dout_ptr, dinx0);
+    vst1q_f32(dout_ptr + 4, dinx1);
+    vst1q_f32(dout_ptr + 8, dinx2);
+    vst1q_f32(dout_ptr + 12, dinx3);
+  }
+  if (remain > 0) {
+    const float* dinx_ptr = dinx + (cnt << 4);
+    const float* diny_ptr = diny + (cnt << 4);
+    float* dout_ptr = dout + (cnt << 4);
+    for (int i = 0; i < remain; i++) {
+      float tmp = *dinx_ptr - *diny_ptr;
+      *dout_ptr = tmp > 0.f ? tmp : 0.f;
+      dout_ptr++;
+      dinx_ptr++;
+      diny_ptr++;
+    }
+  }
+}
+
+template <>
+void elementwise_sub_broadcast<float>(const float* dinx,
+                                      const float* diny,
+                                      float* dout,
+                                      int batch,
+                                      int channels,
+                                      int num) {
+#pragma omp parallel for collapse(2)
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const float* din_ptr = dinx + offset;
+      const float diny_data = diny[j];
+      float* dout_ptr = dout + offset;
+
+      int cnt = num >> 4;
+      int remain = num % 16;
+      float32x4_t rb = vdupq_n_f32(diny_data);
+      for (int k = 0; k < cnt; ++k) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        float32x4_t din2 = vld1q_f32(din_ptr + 8);
+        float32x4_t din3 = vld1q_f32(din_ptr + 12);
+
+        din0 = vsubq_f32(din0, rb);
+        din1 = vsubq_f32(din1, rb);
+        din2 = vsubq_f32(din2, rb);
+        din3 = vsubq_f32(din3, rb);
+
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        vst1q_f32(dout_ptr + 8, din2);
+        vst1q_f32(dout_ptr + 12, din3);
+        din_ptr += 16;
+        dout_ptr += 16;
+      }
+      if (remain >= 8) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        din0 = vsubq_f32(din0, rb);
+        din1 = vsubq_f32(din1, rb);
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        din_ptr += 8;
+        dout_ptr += 8;
+        remain -= 8;
+      }
+      if (remain >= 4) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        din0 = vsubq_f32(din0, rb);
+        vst1q_f32(dout_ptr, din0);
+        din_ptr += 4;
+        dout_ptr += 4;
+        remain -= 4;
+      }
+      if (remain > 0) {
+        for (int p = 0; p < remain; p++) {
+          *dout_ptr = *din_ptr - diny_data;
+          dout_ptr++;
+          din_ptr++;
+        }
+      }
+    }
+  }
+}
+
+template <>
+void elementwise_sub_relu_broadcast<float>(const float* dinx,
+                                           const float* diny,
+                                           float* dout,
+                                           int batch,
+                                           int channels,
+                                           int num) {
+  float32x4_t vzero = vdupq_n_f32(0.f);
+#pragma omp parallel for collapse(2)
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const float* din_ptr = dinx + offset;
+      const float diny_data = diny[j];
+      float* dout_ptr = dout + offset;
+
+      int cnt = num >> 4;
+      int remain = num % 16;
+      float32x4_t rb = vdupq_n_f32(diny_data);
+      for (int k = 0; k < cnt; ++k) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        float32x4_t din2 = vld1q_f32(din_ptr + 8);
+        float32x4_t din3 = vld1q_f32(din_ptr + 12);
+
+        din0 = vsubq_f32(din0, rb);
+        din1 = vsubq_f32(din1, rb);
+        din2 = vsubq_f32(din2, rb);
+        din3 = vsubq_f32(din3, rb);
+
+        // relu
+        din0 = vmaxq_f32(din0, vzero);
+        din1 = vmaxq_f32(din1, vzero);
+        din2 = vmaxq_f32(din2, vzero);
+        din3 = vmaxq_f32(din3, vzero);
+
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        vst1q_f32(dout_ptr + 8, din2);
+        vst1q_f32(dout_ptr + 12, din3);
+        din_ptr += 16;
+        dout_ptr += 16;
+      }
+      if (remain >= 8) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        float32x4_t din1 = vld1q_f32(din_ptr + 4);
+        din0 = vsubq_f32(din0, rb);
+        din1 = vsubq_f32(din1, rb);
+        // relu
+        din0 = vmaxq_f32(din0, vzero);
+        din1 = vmaxq_f32(din1, vzero);
+        vst1q_f32(dout_ptr, din0);
+        vst1q_f32(dout_ptr + 4, din1);
+        din_ptr += 8;
+        dout_ptr += 8;
+        remain -= 8;
+      }
+      if (remain >= 4) {
+        float32x4_t din0 = vld1q_f32(din_ptr);
+        din0 = vsubq_f32(din0, rb);
+        // relu
+        din0 = vmaxq_f32(din0, vzero);
+        vst1q_f32(dout_ptr, din0);
+        din_ptr += 4;
+        dout_ptr += 4;
+        remain -= 4;
+      }
+      if (remain > 0) {
+        for (int p = 0; p < remain; p++) {
+          float tmp = *din_ptr - diny_data;
+          *dout_ptr = tmp > 0.f ? tmp : 0.f;
+          dout_ptr++;
+          din_ptr++;
+        }
+      }
+    }
+  }
+}
+
 template <>
 void elementwise_mul<float>(const float* dinx,
                             const float* diny,
diff --git a/lite/backends/arm/math/elementwise.h b/lite/backends/arm/math/elementwise.h
index 866277ae9c..f8273a5bb3 100644
--- a/lite/backends/arm/math/elementwise.h
+++ b/lite/backends/arm/math/elementwise.h
@@ -33,6 +33,20 @@ template <typename T>
 void elementwise_add_relu_broadcast(
     const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
 
+template <typename T>
+void elementwise_sub(const T* dinx, const T* diny, T* dout, int num);
+
+template <typename T>
+void elementwise_sub_relu(const T* dinx, const T* diny, T* dout, int num);
+
+template <typename T>
+void elementwise_sub_broadcast(
+    const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
+
+template <typename T>
+void elementwise_sub_relu_broadcast(
+    const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
+
 template <typename T>
 void elementwise_mul(const T* dinx, const T* diny, T* dout, int num);
 
diff --git a/lite/kernels/arm/argmax_compute.cc b/lite/kernels/arm/argmax_compute.cc
index 5cb0e48c15..ad279e8f8e 100644
--- a/lite/kernels/arm/argmax_compute.cc
+++ b/lite/kernels/arm/argmax_compute.cc
@@ -40,8 +40,12 @@ void ArgmaxCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(
-    argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def)
+REGISTER_LITE_KERNEL(arg_max,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ArgmaxCompute,
+                     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/arm/argmax_compute_test.cc b/lite/kernels/arm/argmax_compute_test.cc
index ee603efa86..58bdf18474 100644
--- a/lite/kernels/arm/argmax_compute_test.cc
+++ b/lite/kernels/arm/argmax_compute_test.cc
@@ -68,7 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
 
 TEST(argmax_arm, retrive_op) {
   auto argmax = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
-      "argmax");
+      "arg_max");
   ASSERT_FALSE(argmax.empty());
   ASSERT_TRUE(argmax.front());
 }
@@ -136,4 +136,4 @@ TEST(argmax_arm, compute) {
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
-USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(arg_max, kARM, kFloat, kNCHW, def);
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index a0a87628ff..2e57b6a3b3 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -116,6 +116,51 @@ void ElementwiseAddActivationCompute::Run() {
   }
 }
 
+void ElementwiseSubCompute::Run() {
+  auto& param = Param<operators::ElementwiseParam>();
+  const float* x_data = param.X->data<float>();
+  const float* y_data = param.Y->data<float>();
+  float* out_data = param.Out->mutable_data<float>();
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_sub_broadcast(
+        x_data, y_data, out_data, pre, n, post);
+  } else {
+    lite::arm::math::elementwise_sub(
+        x_data, y_data, out_data, x_dims.production());
+  }
+}
+
+void ElementwiseSubActivationCompute::Run() {
+  auto& param = Param<operators::FusionElementwiseActivationParam>();
+  const float* x_data = param.X->data<float>();
+  const float* y_data = param.Y->data<float>();
+  float* out_data = param.Out->mutable_data<float>();
+  int axis = param.axis;
+  std::string act_type = param.act_type;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    if (act_type == "relu") {
+      lite::arm::math::elementwise_sub_relu_broadcast(
+          x_data, y_data, out_data, pre, n, post);
+    } else {
+      LOG(FATAL) << "unsupported Activation type: " << act_type;
+    }
+  } else {
+    if (act_type == "relu") {
+      lite::arm::math::elementwise_sub_relu(
+          x_data, y_data, out_data, x_dims.production());
+    } else {
+      LOG(FATAL) << "unsupported Activation type: " << act_type;
+    }
+  }
+}
+
 void ElementwiseMulCompute::Run() {
   auto& param = Param<operators::ElementwiseParam>();
   const float* x_data = param.X->data<float>();
@@ -249,10 +294,6 @@ void ElementwiseDivActivationCompute::Run() {
       LOG(FATAL) << "unsupported Activation type: " << act_type;
     }
   }
-  for (int i = 0; i < x_dims.production(); i++) {
-    LOG(INFO) << "x:" << x_data[i] << " y:" << y_data[i]
-              << " out:" << out_data[i];
-  }
 }
 
 }  // namespace arm
@@ -283,6 +324,29 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(elementwise_sub,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ElementwiseSubCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_sub_activation,
+    kARM,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::arm::ElementwiseSubActivationCompute,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(elementwise_mul,
                      kARM,
                      kFloat,
diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h
index 003f4d542f..e76449aebc 100644
--- a/lite/kernels/arm/elementwise_compute.h
+++ b/lite/kernels/arm/elementwise_compute.h
@@ -38,6 +38,22 @@ class ElementwiseAddActivationCompute
   virtual ~ElementwiseAddActivationCompute() = default;
 };
 
+class ElementwiseSubCompute
+    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~ElementwiseSubCompute() = default;
+};
+
+class ElementwiseSubActivationCompute
+    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~ElementwiseSubActivationCompute() = default;
+};
+
 class ElementwiseMulCompute
     : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  public:
diff --git a/lite/operators/argmax_op.cc b/lite/operators/argmax_op.cc
index ccfce32bb6..6b246603e1 100644
--- a/lite/operators/argmax_op.cc
+++ b/lite/operators/argmax_op.cc
@@ -50,7 +50,7 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
   param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
 
-  param_.Axis = op_desc.GetAttr<int64_t>("Axis");
+  param_.Axis = op_desc.GetAttr<int64_t>("axis");
 
   return true;
 }
@@ -59,4 +59,4 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
-REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite);
+REGISTER_LITE_OP(arg_max, paddle::lite::operators::ArgmaxOpLite);
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 90cbab1804..392ed6296a 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -761,7 +761,6 @@ struct GenerateProposalsParam {
   lite::Tensor* RpnRois{};
   lite::Tensor* RpnRoiProbs{};
 };
-/// ----------------------- shape operators ----------------------
 /// ----------------------- squeeze operators ----------------------
 struct SqueezeParam {
   const lite::Tensor* X{};
diff --git a/lite/tests/kernels/argmax_compute_test.cc b/lite/tests/kernels/argmax_compute_test.cc
index 49cbd91071..9163e4bdaf 100644
--- a/lite/tests/kernels/argmax_compute_test.cc
+++ b/lite/tests/kernels/argmax_compute_test.cc
@@ -25,7 +25,7 @@ class ArgmaxComputeTester : public arena::TestCase {
   // common attributes for this op.
   std::string input_ = "x";
   std::string output_ = "out";
-  int axis_ = 0;
+  int64_t axis_ = 0;
   DDim dims_{{2, 5, 20, 30}};
 
  public:
@@ -82,10 +82,10 @@ class ArgmaxComputeTester : public arena::TestCase {
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
-    op_desc->SetType("argmax");
+    op_desc->SetType("arg_max");
     op_desc->SetInput("X", {input_});
     op_desc->SetOutput("Out", {output_});
-    op_desc->SetAttr("Axis", axis_);
+    op_desc->SetAttr("axis", axis_);
   }
 
   void PrepareData() override {
diff --git a/lite/tests/kernels/elementwise_compute_test.cc b/lite/tests/kernels/elementwise_compute_test.cc
index ceceb6394a..90f7d02362 100644
--- a/lite/tests/kernels/elementwise_compute_test.cc
+++ b/lite/tests/kernels/elementwise_compute_test.cc
@@ -71,6 +71,57 @@ class ElementwiseComputeTester : public arena::TestCase {
   }
 };
 
+class ElementwiseSubComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string inputx_ = "x";
+  std::string inputy_ = "y";
+  std::string output_ = "out";
+  int axis_;
+  DDim dims_{{1, 2, 3, 4}};
+
+ public:
+  ElementwiseSubComputeTester(const Place& place,
+                              const std::string& alias,
+                              int axis)
+      : TestCase(place, alias), axis_(axis) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto* out = scope->NewTensor(output_);
+    CHECK(out);
+    out->Resize(dims_);
+    auto* out_data = out->mutable_data<float>();
+
+    auto* x = scope->FindTensor(inputx_);
+    const auto* x_data = x->data<float>();
+    auto* y = scope->FindTensor(inputy_);
+    const auto* y_data = y->data<float>();
+
+    for (int i = 0; i < dims_.production(); i++) {
+      out_data[i] = x_data[i] - y_data[i];
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("elementwise_sub");
+    op_desc->SetInput("X", {inputx_});
+    op_desc->SetInput("Y", {inputy_});
+    op_desc->SetOutput("Out", {output_});
+    op_desc->SetAttr("axis", axis_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> data(dims_.production());
+
+    for (int i = 0; i < dims_.production(); i++) {
+      data[i] = i * 1.1;
+    }
+
+    SetCommonTensor(inputx_, dims_, data.data());
+    SetCommonTensor(inputy_, dims_, data.data());
+  }
+};
+
 class ElementwiseMulComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
@@ -232,6 +283,65 @@ class FusionElementwiseAddActivationComputeTester : public arena::TestCase {
   }
 };
 
+class FusionElementwiseSubActivationComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string inputx_ = "x";
+  std::string inputy_ = "y";
+  std::string output_ = "out";
+  int axis_;
+  std::string act_type_;
+  DDim dims_{{1, 2, 3, 4}};
+
+ public:
+  FusionElementwiseSubActivationComputeTester(const Place& place,
+                                              const std::string& alias,
+                                              int axis,
+                                              std::string act_type)
+      : TestCase(place, alias), axis_(axis), act_type_(act_type) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto* out = scope->NewTensor(output_);
+    CHECK(out);
+    out->Resize(dims_);
+    auto* out_data = out->mutable_data<float>();
+
+    auto* x = scope->FindTensor(inputx_);
+    const auto* x_data = x->data<float>();
+    auto* y = scope->FindTensor(inputy_);
+    const auto* y_data = y->data<float>();
+
+    for (int i = 0; i < dims_.production(); i++) {
+      out_data[i] = x_data[i] - y_data[i];
+      if (act_type_ == "relu") {
+        out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
+      } else {
+        LOG(FATAL) << "unsupported Activation type: " << act_type_;
+      }
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("fusion_elementwise_sub_activation");
+    op_desc->SetInput("X", {inputx_});
+    op_desc->SetInput("Y", {inputy_});
+    op_desc->SetOutput("Out", {output_});
+    op_desc->SetAttr("axis", axis_);
+    op_desc->SetAttr("act_type", act_type_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> data(dims_.production());
+
+    for (int i = 0; i < dims_.production(); i++) {
+      data[i] = i * 1.1;
+    }
+
+    SetCommonTensor(inputx_, dims_, data.data());
+    SetCommonTensor(inputy_, dims_, data.data());
+  }
+};
+
 class FusionElementwiseMulActivationComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
@@ -441,7 +551,6 @@ class FusionElementwiseDivActivationComputeTester : public arena::TestCase {
       } else {
         LOG(FATAL) << "unsupported Activation type: " << act_type_;
       }
-      LOG(INFO) << "fusion div resul:" << out_data[i];
     }
   }
 
@@ -476,6 +585,11 @@ void test_elementwise(Place place) {
     arena::Arena arena(std::move(tester), place, 2e-5);
     arena.TestPrecision();
 
+    std::unique_ptr<arena::TestCase> tester_sub(
+        new ElementwiseSubComputeTester(place, "def", axis));
+    arena::Arena arena_sub(std::move(tester_sub), place, 2e-5);
+    arena_sub.TestPrecision();
+
     std::unique_ptr<arena::TestCase> tester_mul(
         new ElementwiseMulComputeTester(place, "def", axis));
     arena::Arena arena_mul(std::move(tester_mul), place, 2e-5);
@@ -511,6 +625,12 @@ void test_fusion_elementwise(Place place) {
     arena::Arena arena_add_act(std::move(tester_add_act), place, 2e-5);
     arena_add_act.TestPrecision();
 
+    std::unique_ptr<arena::TestCase> tester_sub_act(
+        new FusionElementwiseSubActivationComputeTester(
+            place, "def", axis, "relu"));
+    arena::Arena arena_sub_act(std::move(tester_sub_act), place, 2e-5);
+    arena_sub_act.TestPrecision();
+
     std::unique_ptr<arena::TestCase> tester_mul_act(
         new FusionElementwiseMulActivationComputeTester(
             place, "def", axis, "relu"));
--
GitLab
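
Editor's note (not part of the patch): the NEON kernels above all implement the same scalar semantics, and a compact reference makes the unrolled loops easier to audit. The sketch below is a minimal scalar model, assuming the same (pre, n, post) broadcast decomposition that PaddleLite's is_broadcast() computes; broadcast_shape() and elementwise_sub_ref() are hypothetical names used only for illustration.

#include <cassert>
#include <cstdio>
#include <vector>

// Decompose x_dims/y_dims at `axis` into (pre, n, post): y is a contiguous
// block of n values repeated `pre` times, each value covering `post`
// consecutive outputs. Returns false when the shapes already match.
static bool broadcast_shape(const std::vector<int>& x_dims,
                            const std::vector<int>& y_dims,
                            int axis,
                            int* pre, int* n, int* post) {
  if (x_dims == y_dims) return false;  // same shape: no broadcast needed
  if (axis < 0) axis = static_cast<int>(x_dims.size() - y_dims.size());
  *pre = 1; *n = 1; *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) {
    assert(x_dims[axis + i] == y_dims[i]);  // y must align with x at `axis`
    *n *= y_dims[i];
  }
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    *post *= x_dims[i];
  return true;
}

// Scalar reference for the broadcast kernels: out = x - y, optionally relu'd.
// The patch's elementwise_sub_broadcast() walks exactly this index space with
// batch = pre, channels = n, num = post.
static void elementwise_sub_ref(const float* x, const float* y, float* out,
                                int pre, int n, int post, bool fuse_relu) {
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k) {
        int idx = (i * n + j) * post + k;
        float v = x[idx] - y[j];
        out[idx] = (fuse_relu && v < 0.f) ? 0.f : v;
      }
}

int main() {
  // x: {1, 2, 3, 4}, y: {3} broadcast at axis = 2 -> pre = 2, n = 3, post = 4.
  std::vector<int> x_dims{1, 2, 3, 4}, y_dims{3};
  int pre, n, post;
  if (!broadcast_shape(x_dims, y_dims, 2, &pre, &n, &post)) return 1;
  std::vector<float> x(24), y{0.f, 1.f, 2.f}, out(24);
  for (int i = 0; i < 24; ++i) x[i] = 0.1f * i;
  elementwise_sub_ref(x.data(), y.data(), out.data(), pre, n, post, false);
  std::printf("out[4] = %.2f (expect %.2f)\n", out[4], 0.4f - 1.f);
  return 0;
}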
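
A second note on the fused variants: elementwise_sub_relu and elementwise_sub_relu_broadcast fold the activation into the same pass over the data, so each output is written once instead of being re-read by a separate relu kernel. A minimal scalar sketch of the fused semantics, with a hypothetical name and plain pointers:

#include <algorithm>  // std::max

// One pass computes max(x - y, 0); the NEON version in the patch does the
// same 16 lanes at a time via vsubq_f32 followed by vmaxq_f32 against zero.
void elementwise_sub_relu_ref(const float* x, const float* y,
                              float* out, int num) {
  for (int i = 0; i < num; ++i) {
    out[i] = std::max(x[i] - y[i], 0.f);
  }
}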