Unverified commit 605a309c authored by liu zhengxi, committed by GitHub

Add tanh op and gelu op for x86 platform (#2265)

* add tanh op on the x86 platform and its unittest, test=develop

* add gelu op on the x86 platform and its unittests, test=develop

* update math_function dependencies for the gelu activation, test=develop
Parent cec2730e
@@ -32,8 +32,8 @@ math_library(sampler)
math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions)
lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3)
math_library(math_function DEPS blas)
lite_cc_library(blas SRCS blas.cc DEPS cblas framework_proto eigen3 dynload_mklml)
math_library(math_function DEPS blas dynload_mklml)
math_library(maxouting)
math_library(pooling)
math_library(selected_rows_functor DEPS selected_rows math_function blas)
......
add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops)
add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_ops math_function)
# lite_cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
@@ -55,6 +55,8 @@ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS ba
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_tanh_compute_x86 SRCS tanh_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
......
@@ -35,3 +35,25 @@ REGISTER_LITE_KERNEL(relu,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// float
REGISTER_LITE_KERNEL(tanh,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::TanhCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// float
REGISTER_LITE_KERNEL(gelu,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::GeluCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
@@ -13,8 +13,10 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
@@ -115,6 +117,76 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~ReluCompute() = default;
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
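// e.g. tanh(0.16) ≈ 0.158648 (matches the unit test reference data below)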
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.tanh();
}
};
template <typename T>
class TanhCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ActivationParam>();
param.Out->template mutable_data<T>();
Activate<TanhFunctor<T>>(param.X, param.Out);
}
virtual ~TanhCompute() = default;
};
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
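// e.g. gelu(1.6) = 0.8 * (1 + erf(1.6 / sqrt(2))) ≈ 1.512321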
template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
// Because the execution or device context cannot be delivered here, keep the
// macro guard for NVCC.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
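// With MKL available, compute gelu step by step on the flat buffers:
//   out = 0                      (memset)
//   out += x * (1 / sqrt(2))     (AXPY with alpha = M_SQRT1_2)
//   out = erf(out)               (VMERF)
//   out = out + 1                (scalar loop)
//   out = x * out                (element-wise VMUL)
//   out = 0.5 * out              (scalar loop)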
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
paddle::lite::x86::math::CBlas<T>::AXPY(
n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
paddle::lite::x86::math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
paddle::lite::x86::math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
}
};
template <typename T>
class GeluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ActivationParam>();
param.Out->template mutable_data<T>();
Activate<GeluFunctor<T>>(param.X, param.Out);
}
virtual ~GeluCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(gelu_x86, retrive_op) {
auto gelu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gelu");
ASSERT_FALSE(gelu.empty());
ASSERT_TRUE(gelu.front());
}
TEST(gelu_x86, init) {
GeluCompute<float> gelu;
ASSERT_EQ(gelu.precision(), PRECISION(kFloat));
ASSERT_EQ(gelu.target(), TARGET(kX86));
}
TEST(gelu_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
int sign = i % 2 == 0 ? 1 : -1;
x_data[i] = static_cast<float>(i * sign) * 0.8f;
}
// GeluCompute gelu;
GeluCompute<float> gelu;
operators::ActivationParam param;
param.X = &x;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
gelu.SetContext(std::move(ctx));
gelu.SetParam(param);
gelu.Run();
LOG(INFO) << "output: ";
std::vector<float> ref_data{0.,
-0.169484,
1.512321,
-0.019674,
3.197801,
-0.000126719,
4.8,
-0.,
6.4000001,
-0.,
8.,
-0.};
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(gelu, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(tanh_x86, retrive_op) {
auto tanh =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("tanh");
ASSERT_FALSE(tanh.empty());
ASSERT_TRUE(tanh.front());
}
TEST(tanh_x86, init) {
TanhCompute<float> tanh;
ASSERT_EQ(tanh.precision(), PRECISION(kFloat));
ASSERT_EQ(tanh.target(), TARGET(kX86));
}
TEST(tanh_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 3, 2, 2};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 3, 2, 2};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); i++) {
int sign = i % 2 == 0 ? 1 : -1;
x_data[i] = static_cast<float>(i * sign) * 0.08f;
}
// TanhCompute tanh;
TanhCompute<float> tanh;
operators::ActivationParam param;
param.X = &x;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
tanh.SetContext(std::move(ctx));
tanh.SetParam(param);
tanh.Run();
LOG(INFO) << "output: ";
std::vector<float> ref_data{0.,
-0.079829,
0.158648,
-0.235495,
0.309506,
-0.379949,
0.446243,
-0.507977,
0.564899,
-0.616909,
0.664036,
-0.706419};
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
EXPECT_NEAR(out_data[i], ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(tanh, kX86, kFloat, kNCHW, def);