提交 6006a87c 编写于 作者: L lijianshe02 提交者: GitHub

add x86 softmax kernel and fix jit compute bugs test=develop (#2007)

上级 192320c4
...@@ -4,9 +4,9 @@ function(USE_JITKERNEL_MORE TARGET TYPE) ...@@ -4,9 +4,9 @@ function(USE_JITKERNEL_MORE TARGET TYPE)
endfunction() endfunction()
# enable it latter # enable it latter
# if(WITH_MKLML) if(WITH_MKLML)
# add_subdirectory(mkl) add_subdirectory(mkl)
# endif() endif()
if(WITH_AVX) if(WITH_AVX)
add_subdirectory(intrinsic) add_subdirectory(intrinsic)
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "lite/backends/x86/cpu_info.h" #include "lite/backends/x86/cpu_info.h"
#include "lite/backends/x86/jit/helper.h" #include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h" #include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/cpu_vec.h" #include "lite/backends/x86/math/cpu_vec.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/fluid/eigen.h" #include "lite/fluid/eigen.h"
......
...@@ -33,8 +33,10 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps ...@@ -33,8 +33,10 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86) lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
...@@ -14,12 +14,9 @@ ...@@ -14,12 +14,9 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "lite/backends/x86/math/softmax.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -55,7 +52,7 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -55,7 +52,7 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = *param_.get_mutable<operators::SoftmaxParam>(); auto& param = *param_.get_mutable<operators::SoftmaxParam>();
// auto& context = context_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
CHECK(param.output); CHECK(param.output);
CHECK(param.x); CHECK(param.x);
param.output->mutable_data<T>(); param.output->mutable_data<T>();
...@@ -72,13 +69,8 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -72,13 +69,8 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
out_2d.ShareDataWith(*param.output); out_2d.ShareDataWith(*param.output);
out_2d.Resize(lite::DDim(shape)); out_2d.Resize(lite::DDim(shape));
paddle::operators::math::SoftmaxFunctor<platform::CPUDeviceContext, lite::x86::math::SoftmaxFunctor<lite::TargetType::kX86, T, true>()(
T, context, axis_dim, &input_2d, &out_2d);
true>()(
platform::CPUDeviceContext(),
axis_dim,
&input_2d.raw_tensor(),
&out_2d.raw_tensor());
} }
virtual ~SoftmaxCompute() = default; virtual ~SoftmaxCompute() = default;
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include "lite/kernels/x86/softmax_compute.h" #include "lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -54,15 +55,24 @@ TEST(softmax_x86, run_test) { ...@@ -54,15 +55,24 @@ TEST(softmax_x86, run_test) {
SoftmaxCompute<float> softmax; SoftmaxCompute<float> softmax;
operators::SoftmaxParam param; operators::SoftmaxParam param;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
softmax.SetContext(std::move(ctx));
param.x = &x; param.x = &x;
param.output = &out; param.output = &out;
softmax.SetParam(param); softmax.SetParam(param);
softmax.Run(); softmax.Run();
LOG(INFO) << "output: "; std::vector<float> ref_results = {
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241};
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册