提交 2c9ef4b7 编写于 作者: T tensor-tang

Merge branch 'zhupy/add-dropout-arm-kernel' into 'incubate/lite'

add dropout arm kernel and unit test

See merge request inference/paddlelite!15
......@@ -33,6 +33,7 @@ cc_library(math_arm SRCS
conv_winograd.cc
split.cc
activation.cc
dropout.cc
DEPS ${lite_kernel_deps} eigen3 framework_proto_lite)
# TODO(TJ): fix me, this target should not depend on proto
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/dropout.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
// float specialization of dropout_down.
//
// Writes din[i] * (1 - prob) into dout[i] for every i in [0, num).
// The bulk of the data is processed in groups of 16 floats (four 4-lane NEON
// vectors per group), with the groups spread over an OpenMP parallel-for; the
// trailing num % 16 elements are handled by a scalar loop.
template <>
void dropout_down<float>(const float* din, float* dout, int num, float prob) {
  const float scale = 1.0f - prob;
  const int block_count = num >> 4;
  const int tail_len = num % 16;
  const float32x4_t scale_vec = vdupq_n_f32(scale);

#pragma omp parallel for
  for (int blk = 0; blk < block_count; ++blk) {
    const float* src = din + blk * 16;
    float* dst = dout + blk * 16;

    // All four loads are issued before the first store.
    const float32x4_t scaled0 = vmulq_f32(vld1q_f32(src), scale_vec);
    const float32x4_t scaled1 = vmulq_f32(vld1q_f32(src + 4), scale_vec);
    const float32x4_t scaled2 = vmulq_f32(vld1q_f32(src + 8), scale_vec);
    const float32x4_t scaled3 = vmulq_f32(vld1q_f32(src + 12), scale_vec);

    vst1q_f32(dst, scaled0);
    vst1q_f32(dst + 4, scaled1);
    vst1q_f32(dst + 8, scaled2);
    vst1q_f32(dst + 12, scaled3);
  }

  // Scalar tail: the elements that did not fill a whole group of 16.
  const int tail_begin = block_count * 16;
  for (int k = 0; k < tail_len; ++k) {
    dout[tail_begin + k] = din[tail_begin + k] * scale;
  }
}
// float specialization of dropout_up.
//
// Copies din[i] to dout[i] unchanged for every i in [0, num).
// Groups of 16 floats are moved through four NEON registers inside an OpenMP
// parallel-for; the trailing num % 16 elements are copied one by one.
template <>
void dropout_up<float>(const float* din, float* dout, int num) {
  const int block_count = num >> 4;
  const int tail_len = num % 16;

#pragma omp parallel for
  for (int blk = 0; blk < block_count; ++blk) {
    const float* src = din + blk * 16;
    float* dst = dout + blk * 16;

    // Load the whole group first, then store it.
    float32x4_t quad[4];
    for (int q = 0; q < 4; ++q) {
      quad[q] = vld1q_f32(src + 4 * q);
    }
    for (int q = 0; q < 4; ++q) {
      vst1q_f32(dst + 4 * q, quad[q]);
    }
  }

  // Scalar tail: the elements that did not fill a whole group of 16.
  const int tail_begin = block_count * 16;
  for (int k = 0; k < tail_len; ++k) {
    dout[tail_begin + k] = din[tail_begin + k];
  }
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
// Element-wise dropout helper that takes the dropout probability.
// din:  input buffer of `num` elements.
// dout: output buffer of `num` elements.
// prob: dropout probability.
// The float specialization writes din[i] * (1 - prob) to dout[i].
template <typename T>
void dropout_down(const T* din, T* dout, int num, float prob);
// Element-wise dropout helper that takes no probability.
// din:  input buffer of `num` elements.
// dout: output buffer of `num` elements.
// The float specialization copies din[i] to dout[i] unchanged.
template <typename T>
void dropout_up(const T* din, T* dout, int num);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,7 @@ cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(dropout_compute_arm SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_activation_compute_arm SRCS activation_compute_test.cc DEPS activation_compute_arm)
......@@ -25,6 +26,7 @@ lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
set(arm_kernels
fc_compute_arm
......@@ -37,6 +39,7 @@ set(arm_kernels
elementwise_add_compute_arm
pool_compute_arm
split_compute_arm
dropout_compute_arm
)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/dropout_compute.h"
#include <string>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Runs inference-time dropout on the ARM float kernel.
//
// Reads the input tensor `x`, the dropout probability and the implementation
// mode from DropoutParam and writes the result into `output`:
//   - "downgrade_in_infer": out = x * (1 - dropout_prob)  (dropout_down)
//   - any other mode (e.g. "upscale_in_train"): out = x    (dropout_up)
void DropoutCompute::Run() {
  auto& param = Param<operators::DropoutParam>();
  const float* x_data = param.x->data<float>();
  float* out_data = param.output->mutable_data<float>();
  const int num = param.x->dims().production();
  const float prob_data = param.dropout_prob;

  // std::string::compare() returns 0 when the strings are equal, so using its
  // raw result as the condition selected the wrong branch. Test equality
  // explicitly instead.
  if (param.dropout_implementation == "downgrade_in_infer") {
    lite::arm::math::dropout_down(x_data, out_data, num, prob_data);
  } else {
    lite::arm::math::dropout_up(x_data, out_data, num);
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers DropoutCompute as the "dropout" kernel for ARM / float / NCHW
// under the alias "def".
// NOTE(review): Run() reads dropout_prob and dropout_implementation as plain
// fields of DropoutParam (a float and a string), yet they are bound here as
// tensor inputs - confirm whether these two BindInput entries are intended.
REGISTER_LITE_KERNEL(dropout, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::DropoutCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("dropout_prob", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("dropout_implementation", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Lite kernel for the dropout operator, targeting ARM with float precision.
class DropoutCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
// Executes the kernel; overrides the Run() of the KernelLite base.
void Run() override;
virtual ~DropoutCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/dropout_compute.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// A default-constructed kernel must report float precision and the ARM target.
TEST(dropout_arm, init) {
  DropoutCompute kernel;
  ASSERT_EQ(kernel.precision(), PRECISION(kFloat));
  ASSERT_EQ(kernel.target(), TARGET(kARM));
}
// The kernel registry must hand back at least one non-null ARM/float kernel
// for the op name "dropout".
TEST(dropout, retrive_op) {
  auto kernels =
      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
          "dropout");
  ASSERT_FALSE(kernels.empty());
  ASSERT_TRUE(kernels.front());
}
// Scalar reference implementation of inference-time dropout, used to check
// the ARM kernel.
//
// Reads param.x and writes param.output (both treated as float buffers):
//   - "downgrade_in_infer": output[i] = x[i] * (1 - dropout_prob)
//   - any other mode (e.g. "upscale_in_train"): output[i] = x[i]
template <typename dtype>
void dropout_compute_ref(const operators::DropoutParam& param) {
  const float* x_data = param.x->data<float>();
  float* output_data = param.output->mutable_data<float>();
  const int num = param.x->dims().production();
  const float prob_data = param.dropout_prob;

  // std::string::compare() yields 0 on a match, so its raw result must not be
  // used as the "mode matches" predicate; compare for equality explicitly.
  if (param.dropout_implementation == "downgrade_in_infer") {
    const float scale = 1.0f - prob_data;
    for (int i = 0; i < num; i++) {
      output_data[i] = x_data[i] * scale;
    }
  } else {
    for (int i = 0; i < num; i++) {
      output_data[i] = x_data[i];
    }
  }
}
// Runs the ARM kernel and the scalar reference over a grid of NCHW shapes,
// dropout probabilities and implementation modes, and checks that both
// produce the same output element by element (tolerance 1e-5).
TEST(dropout_arm, compute) {
  DropoutCompute kernel;
  operators::DropoutParam param;
  lite::Tensor input;
  lite::Tensor actual;
  lite::Tensor expected;

  const std::vector<int> extents{1, 3, 4};
  const std::vector<float> probs{0.2f, 0.8f};
  const std::vector<std::string> impls{"downgrade_in_infer",
                                       "upscale_in_train"};

  for (int n : extents) {
    for (int c : extents) {
      for (int h : extents) {
        for (int w : extents) {
          for (float prob : probs) {
            for (const auto& impl : impls) {
              const std::vector<int64_t> shape{n, c, h, w};
              input.Resize(DDim(shape));
              actual.Resize(DDim(shape));
              expected.Resize(DDim(shape));

              auto* input_data = input.mutable_data<float>();
              auto* actual_data = actual.mutable_data<float>();
              auto* expected_data = expected.mutable_data<float>();

              // Fill the input with its own linear index.
              for (int i = 0; i < input.dims().production(); ++i) {
                input_data[i] = i;
              }

              param.x = &input;
              param.output = &actual;
              param.dropout_prob = prob;
              param.dropout_implementation = impl;

              // Kernel under test writes into `actual`.
              kernel.SetParam(param);
              kernel.Run();

              // Reference writes into `expected` using the same settings.
              param.output = &expected;
              dropout_compute_ref<float>(param);

              for (int i = 0; i < actual.dims().production(); ++i) {
                EXPECT_NEAR(actual_data[i], expected_data[i], 1e-5);
              }
            }
          }
        }
      }
    }
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Presumably forces the dropout ARM/float/NCHW "def" kernel registration to be
// linked into this test binary - confirm against USE_LITE_KERNEL's definition.
USE_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册