提交 01c0868e 编写于 作者: H hong19860320 提交者: GitHub

enable softmax op and add unit test (#17703)

* enable softmax op and add unit test

* move softmax sub-functions to softmax.cc, and move basic math functions to funcs.h
上级 244a9e06
...@@ -72,6 +72,7 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); ...@@ -72,6 +72,7 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); // USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); // USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
......
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3) cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3)
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -8,13 +8,16 @@ cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) ...@@ -8,13 +8,16 @@ cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
set(arm_kernels set(arm_kernels
fc_compute_arm fc_compute_arm
relu_compute_arm relu_compute_arm
mul_compute_arm mul_compute_arm
scale_compute_arm) scale_compute_arm
softmax_compute_arm)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SoftmaxCompute::Run() {
auto& param = Param<operators::SoftmaxParam>();
const float* din = param.x->data<float>();
float* dout = param.output->mutable_data<float>();
auto dim_x = param.x->dims();
auto rank_x = dim_x.size();
int axis = param.axis;
if (axis < 0) {
axis += rank_x;
}
int outer_num = dim_x.Slice(0, axis).production();
int inner_num = dim_x.Slice(axis + 1, rank_x).production();
int axis_size = dim_x[axis];
if (inner_num == 1) {
if (axis_size >= 4) {
lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num,
axis_size);
} else {
lite::arm::math::softmax_inner1_small_axis(din, dout, outer_num,
axis_size);
}
} else {
int compute_size = outer_num * inner_num;
if (axis_size == 4 && inner_num % 8 == 0) {
lite::arm::math::softmax_inner8_axis4(din, dout, axis_size, inner_num,
outer_num);
} else if (axis_size == 4 && inner_num % 4 == 0) {
lite::arm::math::softmax_inner4_axis4(din, dout, axis_size, inner_num,
outer_num);
} else {
if (inner_num % 8 == 0) {
lite::arm::math::softmax_inner8(din, dout, axis_size, inner_num,
outer_num);
} else if (inner_num % 4 == 0) {
lite::arm::math::softmax_inner4(din, dout, axis_size, inner_num,
outer_num);
} else {
lite::arm::math::softmax_basic(din, dout, axis_size, inner_num,
outer_num);
}
}
}
}
TargetType SoftmaxCompute::target() const { return TARGET(kARM); }
PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(softmax, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::SoftmaxCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~SoftmaxCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void softmat_compute_ref(const operators::SoftmaxParam& param) {
const dtype* x_data = param.x->mutable_data<const dtype>();
dtype* output_data = param.output->mutable_data<dtype>();
DDim dim = param.x->dims();
ASSERT_EQ(dim.data(), param.output->dims().data());
auto rank = dim.size();
int axis = param.axis;
if (axis < 0) {
axis += rank;
}
int axis_size = dim[axis];
int outer_num = dim.Slice(0, axis).production();
int inner_num = dim.Slice(axis + 1, rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
output_data[offset] = exp(x_data[offset] - max_data);
sum_data += output_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
output_data[offset] /= sum_data;
offset += inner_num;
}
}
}
TEST(softmax_arm, init) {
SoftmaxCompute softmax;
ASSERT_EQ(softmax.precision(), PRECISION(kFloat));
ASSERT_EQ(softmax.target(), TARGET(kARM));
}
TEST(softmax_arm, compute) {
SoftmaxCompute softmax;
operators::SoftmaxParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4}) {
for (auto h : {3, 1, 11, 4}) {
for (auto w : {1, 3, 4, 12}) {
for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
param.x = &x;
param.axis = axis;
param.output = &output;
softmax.SetParam(param);
softmax.Run();
param.output = &output_ref;
softmat_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
TEST(softmax, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
...@@ -18,5 +18,6 @@ ...@@ -18,5 +18,6 @@
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
...@@ -4,6 +4,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS}) ...@@ -4,6 +4,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS}) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS}) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
...@@ -19,6 +20,7 @@ set(ops_lite ...@@ -19,6 +20,7 @@ set(ops_lite
relu_op_lite relu_op_lite
mul_op_lite mul_op_lite
scale_op_lite scale_op_lite
softmax_op_lite
feed_op_lite feed_op_lite
fetch_op_lite fetch_op_lite
io_copy_op_lite io_copy_op_lite
...@@ -28,3 +30,4 @@ set(ops_lite ...@@ -28,3 +30,4 @@ set(ops_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
...@@ -93,6 +93,14 @@ struct ScaleParam { ...@@ -93,6 +93,14 @@ struct ScaleParam {
bool bias_after_scale{true}; bool bias_after_scale{true};
}; };
// For Softmax Op
struct SoftmaxParam {
lite::Tensor* x{};
lite::Tensor* output{};
int axis{-1};
};
/// ----------------------- element wise operators ---------------------- /// ----------------------- element wise operators ----------------------
struct ElementwiseParam { struct ElementwiseParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto dim_x = param_.x->dims();
auto rank_x = dim_x.size();
CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x);
return true;
}
bool SoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims());
return true;
}
bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.axis = GetAttr<int>(opdesc.GetAttr("axis"));
CHECK(param_.x);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
public:
SoftmaxOp() {}
explicit SoftmaxOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "softmax"; }
private:
mutable SoftmaxParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(softmax_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({10, 20})));
output->Resize(DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
lite::OpDesc desc;
desc.SetType("softmax");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("axis", static_cast<int>(-1));
SoftmaxOp softmax("softmax");
softmax.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
softmax.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册