未验证 提交 81132a32 编写于 作者: L lijianshe02 提交者: GitHub

add x86 softmax kernel and fix jit compute bugs test=develop (#2007)

上级 62ea82d0
......@@ -4,9 +4,9 @@ function(USE_JITKERNEL_MORE TARGET TYPE)
endfunction()
# enable it latter
# if(WITH_MKLML)
# add_subdirectory(mkl)
# endif()
if(WITH_MKLML)
add_subdirectory(mkl)
endif()
if(WITH_AVX)
add_subdirectory(intrinsic)
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "lite/backends/x86/cpu_info.h"
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/cpu_vec.h"
#include "lite/core/tensor.h"
#include "lite/fluid/eigen.h"
......
......@@ -33,8 +33,10 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
......@@ -14,12 +14,9 @@
#pragma once
#include <vector>
#include "lite/backends/x86/math/softmax.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -55,7 +52,7 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override {
auto& param = *param_.get_mutable<operators::SoftmaxParam>();
// auto& context = context_->As<X86Context>();
auto& context = ctx_->As<X86Context>();
CHECK(param.output);
CHECK(param.x);
param.output->mutable_data<T>();
......@@ -72,13 +69,8 @@ class SoftmaxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
out_2d.ShareDataWith(*param.output);
out_2d.Resize(lite::DDim(shape));
paddle::operators::math::SoftmaxFunctor<platform::CPUDeviceContext,
T,
true>()(
platform::CPUDeviceContext(),
axis_dim,
&input_2d.raw_tensor(),
&out_2d.raw_tensor());
lite::x86::math::SoftmaxFunctor<lite::TargetType::kX86, T, true>()(
context, axis_dim, &input_2d, &out_2d);
}
virtual ~SoftmaxCompute() = default;
......
......@@ -14,7 +14,8 @@
#include "lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
......@@ -54,15 +55,24 @@ TEST(softmax_x86, run_test) {
SoftmaxCompute<float> softmax;
operators::SoftmaxParam param;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
softmax.SetContext(std::move(ctx));
param.x = &x;
param.output = &out;
softmax.SetParam(param);
softmax.Run();
LOG(INFO) << "output: ";
std::vector<float> ref_results = {
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241, 0.0900306, 0.244728, 0.665241,
0.0900306, 0.244728, 0.665241};
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册