未验证 提交 fa8c8971 编写于 作者: L liu zhengxi 提交者: GitHub

Add layer_norm op on Lite x86 platform (#2463)

上级 134c138f
......@@ -19,6 +19,7 @@ add_kernel(pool_compute_x86 X86 basic SRCS pool_compute.cc DEPS ${lite_kernel_de
add_kernel(stack_compute_x86 X86 basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
add_kernel(dropout_compute_x86 X86 basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_x86 X86 basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_function)
add_kernel(layer_norm_compute_x86 X86 basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} jit_kernel_helper)
# add_kernel(fc_compute_x86 X86 basic SRCS fc_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} )
......@@ -83,6 +84,7 @@ lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_layer_norm_compute_x86 SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
lite_cc_test(test_search_fc_compute_x86 SRCS search_fc_compute_test.cc DEPS search_fc_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/layer_norm_compute.h"
REGISTER_LITE_KERNEL(layer_norm,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::LayerNormCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/layer_norm_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class LayerNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::LayerNormParam;
void Run() override {
auto &param = *param_.get_mutable<param_t>();
float epsilon = param.epsilon;
auto Scale = param.Scale;
auto Bias = param.Bias;
auto x = param.X;
auto y = param.Y;
auto Mean = param.Mean;
auto Var = param.Variance;
auto begin_norm_axis = param.begin_norm_axis;
auto x_dims = x->dims();
y->mutable_data<T>();
Mean->mutable_data<T>();
Var->mutable_data<T>();
auto matrix_dim = x_dims.Flatten2D(begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
lite::DDim matrix_shape({left, right});
lite::Tensor in;
in.ShareDataWith(*x);
in.Resize(matrix_shape);
lite::Tensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
PADDLE_ENFORCE_EQ(Mean->numel(), left);
PADDLE_ENFORCE_EQ(Var->numel(), left);
PADDLE_ENFORCE_EQ(Scale->numel(), right);
PADDLE_ENFORCE_EQ(Bias->numel(), right);
auto ker = paddle::lite::jit::KernelFuncs<jit::LayerNormTuple<T>,
lite::fluid::CPUPlace>::Cache()
.At(right);
ker(in.mutable_data<T>(),
out.mutable_data<T>(),
Mean->mutable_data<T>(),
Var->mutable_data<T>(),
Scale->data<T>(),
Bias->data<T>(),
static_cast<int>(left),
static_cast<const float>(epsilon),
right);
}
virtual ~LayerNormCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
std::vector<float> ref(lite::Tensor* x,
lite::Tensor* Scale,
lite::Tensor* Bias,
lite::Tensor* y,
lite::Tensor* Mean,
lite::Tensor* Var,
int begin_norm_axis,
float epsilon) {
auto x_dims = x->dims();
y->mutable_data<float>();
Mean->mutable_data<float>();
Var->mutable_data<float>();
auto matrix_dim = x_dims.Flatten2D(begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
lite::DDim matrix_shape({left, right});
x->Resize(matrix_shape);
Tensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
auto ker = paddle::lite::jit::KernelFuncs<jit::LayerNormTuple<float>,
lite::fluid::CPUPlace>::Cache()
.At(right);
ker(x->mutable_data<float>(),
out.mutable_data<float>(),
Mean->mutable_data<float>(),
Var->mutable_data<float>(),
Scale->data<float>(),
Bias->data<float>(),
static_cast<int>(left),
static_cast<const float>(epsilon),
right);
std::vector<float> ref_data;
auto result = out.mutable_data<float>();
for (int i = 0; i < y->dims().production(); ++i) {
ref_data.emplace_back(result[i]);
}
return ref_data;
}
// layer_norm
TEST(layer_norm_x86, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
TEST(layer_norm_x86, init) {
lite::kernels::x86::LayerNormCompute<float> layer_norm;
ASSERT_EQ(layer_norm.precision(), PRECISION(kFloat));
ASSERT_EQ(layer_norm.target(), TARGET(kX86));
}
TEST(layer_norm_x86, run_test) {
lite::Tensor x;
lite::Tensor Scale;
lite::Tensor Bias;
lite::Tensor out;
lite::Tensor Mean;
lite::Tensor Var;
std::vector<int64_t> x_shape({1, 2, 3, 1});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({1, 2, 3, 1});
out.Resize(lite::DDim(out_shape));
int begin_norm_axis = 0;
float epsilon = 1e-5;
int pre = 1;
int post = 1;
for (int i = 0; i < begin_norm_axis; ++i) {
pre *= x_shape[i];
}
for (int i = begin_norm_axis; i < x_shape.size(); ++i) {
post *= x_shape[i];
}
std::vector<int64_t> scale_shape({post});
Scale.Resize(scale_shape);
std::vector<int64_t> bias_shape({post});
Bias.Resize(bias_shape);
auto x_data = x.mutable_data<float>();
auto scale_data = Scale.mutable_data<float>();
auto bias_data = Bias.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto mean_data = Mean.mutable_data<float>();
auto var_data = Var.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
for (int64_t i = 0; i < Scale.dims().production(); ++i) {
scale_data[i] = 1.5;
}
for (int64_t i = 0; i < Bias.dims().production(); ++i) {
bias_data[i] = 0.25;
}
LayerNormCompute<float> layer_norm;
operators::LayerNormParam param;
param.X = &x;
param.Y = &out;
param.Scale = &Scale;
param.Bias = &Bias;
param.Mean = &Mean;
param.Variance = &Var;
param.begin_norm_axis = begin_norm_axis;
param.epsilon = epsilon;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
layer_norm.SetContext(std::move(ctx));
layer_norm.SetParam(param);
layer_norm.Run();
std::vector<float> ref_data =
ref(&x, &Scale, &Bias, &out, &Mean, &Var, begin_norm_axis, epsilon);
for (int j = 0; j < out.dims().production(); ++j) {
EXPECT_NEAR(out_data[j], ref_data[j], 1e-5);
// LOG(INFO) << out_data[j];
}
LOG(INFO) << *mean_data;
LOG(INFO) << *var_data;
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(layer_norm, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册