提交 0fbcd4ea 编写于 作者: H hong19860320

add arm kernel and unit test for relue op

test=develop
上级 7c02e682
......@@ -32,6 +32,7 @@ cc_library(math_arm SRCS
conv_winograd_3x3.cc
conv_winograd.cc
split.cc
activation.cc
DEPS ${lite_kernel_deps} eigen3 framework_proto_lite)
# TODO(TJ): fix me do not deps proto
......
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void act_relu(const T* din, T* dout, int size, int threads);
template <typename T>
void act_relu_neg(const T* din, T* dout, int size, const float negative_slope,
int threads);
template <typename T>
void act_clipped_relu(const T* din, T* dout, int size, const float coef,
int threads);
template <typename T>
void act_prelu(const T* din, T* dout, int outer_size, int channel_size,
int inner_size, bool channel_shared, float* channel_slope,
int threads);
template <typename T>
void act_sigmoid(const T* din, T* dout, int size, int threads);
template <typename T>
void act_tanh(const T* din, T* dout, int size, int threads);
template <typename T>
void act_swish(const T* din, T* dout, int size, const float coef, int threads);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -5,7 +5,7 @@ endif()
message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(activation_compute_arm SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -16,6 +16,7 @@ cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_a
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_activation_compute_arm SRCS activation_compute_test.cc DEPS activation_compute_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
......@@ -27,7 +28,7 @@ lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_comput
set(arm_kernels
fc_compute_arm
relu_compute_arm
activation_compute_arm
mul_compute_arm
scale_compute_arm
softmax_compute_arm
......
......@@ -12,31 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/kernels/arm/activation_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ReluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override {
auto& param = Param<operators::ReluParam>();
auto n = param.input->dims().production();
const float* input = param.input->data<float>();
float* output = param.output->mutable_data<float>();
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
}
TargetType target() const override { return TARGET(kARM); }
PrecisionType precision() const override { return PRECISION(kFloat); }
};
void ReluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_relu<float>(x_data, output_data, x_dims.production(),
ctx.threads());
}
} // namespace arm
} // namespace kernels
......
......@@ -12,4 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/relu_compute.h"
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ReluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~ReluCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/activation_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void activation_compute_ref(const operators::ActivationParam& param) {
auto x_data = param.X->data<dtype>();
auto output_data = param.Out->mutable_data<dtype>();
DDim x_dims = param.X->dims();
DDim output_dims = param.Out->dims();
ASSERT_EQ(x_dims.data(), output_dims.data());
for (int i = 0; i < output_dims.production(); i++) {
output_data[i] = std::max(0.f, x_data[i]);
}
}
TEST(activation_arm, retrive_op) {
auto activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("relu");
ASSERT_FALSE(activation.empty());
ASSERT_TRUE(activation.front());
}
TEST(activation_arm, init) {
ReluCompute activation;
ASSERT_EQ(activation.precision(), PRECISION(kFloat));
ASSERT_EQ(activation.target(), TARGET(kARM));
}
TEST(activation_arm, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto c : {6, 32 /*, 128*/}) {
for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) {
for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) {
Tensor x;
Tensor output;
Tensor output_ref;
// set the dims of input, output, ref output tensors
x.Resize({n, c, h, w});
output.Resize({n, c, h, w});
output_ref.Resize({n, c, h, w});
// initialize the data of input tensors
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = sign * static_cast<float>(i % 128) * 0.013f;
}
// prepare kernel params and run
ReluCompute activation;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
activation.SetContext(std::move(ctx));
operators::ActivationParam param;
param.X = &x;
param.Out = &output;
activation.SetParam(param);
activation.Launch();
// invoking ref implementation and compare results
param.Out = &output_ref;
activation_compute_ref<float>(param);
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册