From 605a309c77b0e4d42f6d8608d65692b2becf3d4e Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Tue, 29 Oct 2019 20:04:17 +0800 Subject: [PATCH] Add tanh op and gelu op for x86 platform (#2265) * add tanh op in x86 platform and its unittest, test=develop * add gelu op on x86 platform and add its unittests, test=develop * update depends for math_function for activation for gelu, test=develop --- lite/backends/x86/math/CMakeLists.txt | 4 +- lite/kernels/x86/CMakeLists.txt | 4 +- lite/kernels/x86/activation_compute.cc | 22 ++++++ lite/kernels/x86/activation_compute.h | 72 ++++++++++++++++++++ lite/kernels/x86/gelu_compute_test.cc | 92 ++++++++++++++++++++++++++ lite/kernels/x86/tanh_compute_test.cc | 92 ++++++++++++++++++++++++++ 6 files changed, 283 insertions(+), 3 deletions(-) create mode 100644 lite/kernels/x86/gelu_compute_test.cc create mode 100644 lite/kernels/x86/tanh_compute_test.cc diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt index 5cc4a9f077..2dea4364d5 100644 --- a/lite/backends/x86/math/CMakeLists.txt +++ b/lite/backends/x86/math/CMakeLists.txt @@ -32,8 +32,8 @@ math_library(sampler) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) -lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3) -math_library(math_function DEPS blas) +lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3 dynload_mklml) +math_library(math_function DEPS blas dynload_mklml) math_library(maxouting) math_library(pooling) math_library(selected_rows_functor DEPS selected_rows math_function blas) diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 60219e3b18..6d47c880c8 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -1,4 +1,4 @@ -add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops) +add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops math_function) # lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps}) @@ -55,6 +55,8 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86) +lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86) +lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86) lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86) lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86) diff --git a/lite/kernels/x86/activation_compute.cc b/lite/kernels/x86/activation_compute.cc index 0ed09c43a5..b4a053419c 100644 --- a/lite/kernels/x86/activation_compute.cc +++ b/lite/kernels/x86/activation_compute.cc @@ -35,3 +35,25 @@ REGISTER_LITE_KERNEL(relu, .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +// float +REGISTER_LITE_KERNEL(tanh, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::TanhCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +// float +REGISTER_LITE_KERNEL(gelu, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::GeluCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/activation_compute.h b/lite/kernels/x86/activation_compute.h index 2775240194..482684b067 100644 --- a/lite/kernels/x86/activation_compute.h +++ b/lite/kernels/x86/activation_compute.h @@ -13,8 +13,10 @@ // limitations under the License. #pragma once +#include #include #include +#include "lite/backends/x86/math/blas.h" #include "lite/core/kernel.h" #include "lite/core/op_lite.h" #include "lite/core/op_registry.h" @@ -115,6 +117,76 @@ class ReluCompute : public KernelLite { virtual ~ReluCompute() = default; }; +// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct TanhFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.tanh(); + } +}; + +template +class TanhCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override { + auto& param = *param_.get_mutable(); + + param.Out->template mutable_data(); + Activate>(param.X, param.Out); + } + + virtual ~TanhCompute() = default; +}; + +// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) +template +struct GeluFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { +// Because the execute or device context can not be deliver here, it keep the +// marco for NVCC. +#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ + !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + auto x_data = x.data(); + auto out_data = out.data(); + int n = std::min(x.size(), out.size()); + + std::memset(out_data, 0, n * sizeof(T)); + paddle::lite::x86::math::CBlas::AXPY( + n, static_cast(M_SQRT1_2), x_data, 1, out_data, 1); + paddle::lite::x86::math::CBlas::VMERF(n, out_data, out_data, VML_LA); + for (int i = 0; i < n; i++) { + out_data[i] += static_cast(1); + } + paddle::lite::x86::math::CBlas::VMUL(n, x_data, out_data, out_data); + for (int i = 0; i < n; i++) { + out_data[i] *= static_cast(0.5); + } +#else + auto temp = (x * static_cast(M_SQRT1_2)).erf(); + out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); +#endif + } +}; + +template +class GeluCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override { + auto& param = *param_.get_mutable(); + + param.Out->template mutable_data(); + Activate>(param.X, param.Out); + } + + virtual ~GeluCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite diff --git a/lite/kernels/x86/gelu_compute_test.cc b/lite/kernels/x86/gelu_compute_test.cc new file mode 100644 index 0000000000..20479760e9 --- /dev/null +++ b/lite/kernels/x86/gelu_compute_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/x86/activation_compute.cc" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(gelu_x86, retrive_op) { + auto gelu = + KernelRegistry::Global().Create("gelu"); + ASSERT_FALSE(gelu.empty()); + ASSERT_TRUE(gelu.front()); +} + +TEST(gelu_x86, init) { + GeluCompute gelu; + ASSERT_EQ(gelu.precision(), PRECISION(kFloat)); + ASSERT_EQ(gelu.target(), TARGET(kX86)); +} + +TEST(gelu_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector out_shape{batch_size, 3, 2, 2}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + int sign = i % 2 == 0 ? 1 : -1; + x_data[i] = static_cast(i * sign) * 0.8f; + } + // GeluCompute gelu; + GeluCompute gelu; + operators::ActivationParam param; + + param.X = &x; + param.Out = &out; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + gelu.SetContext(std::move(ctx)); + gelu.SetParam(param); + gelu.Run(); + + LOG(INFO) << "output: "; + std::vector ref_data{0., + -0.169484, + 1.512321, + -0.019674, + 3.197801, + -0.000126719, + 4.8, + -0., + 6.4000001, + -0., + 8., + -0.}; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(gelu, kX86, kFloat, kNCHW, def); diff --git a/lite/kernels/x86/tanh_compute_test.cc b/lite/kernels/x86/tanh_compute_test.cc new file mode 100644 index 0000000000..fa65ca02df --- /dev/null +++ b/lite/kernels/x86/tanh_compute_test.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/x86/activation_compute.cc" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(tanh_x86, retrive_op) { + auto tanh = + KernelRegistry::Global().Create("tanh"); + ASSERT_FALSE(tanh.empty()); + ASSERT_TRUE(tanh.front()); +} + +TEST(tanh_x86, init) { + TanhCompute tanh; + ASSERT_EQ(tanh.precision(), PRECISION(kFloat)); + ASSERT_EQ(tanh.target(), TARGET(kX86)); +} + +TEST(tanh_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 3, 2, 2}; + x.Resize(lite::DDim(x_shape)); + std::vector out_shape{batch_size, 3, 2, 2}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + int sign = i % 2 == 0 ? 1 : -1; + x_data[i] = static_cast(i * sign) * 0.08f; + } + // TanhCompute tanh; + TanhCompute tanh; + operators::ActivationParam param; + + param.X = &x; + param.Out = &out; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + tanh.SetContext(std::move(ctx)); + tanh.SetParam(param); + tanh.Run(); + + LOG(INFO) << "output: "; + std::vector ref_data{0., + -0.079829, + 0.158648, + -0.235495, + 0.309506, + -0.379949, + 0.446243, + -0.507977, + 0.564899, + -0.616909, + 0.664036, + -0.706419}; + for (int i = 0; i < out.dims().production(); i++) { + LOG(INFO) << out_data[i]; + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(tanh, kX86, kFloat, kNCHW, def); -- GitLab