diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 412fc72b56f65b99e7d3f29915c99cd151a1790e..2f4abd0b644e9af10e9a88f126c339e9a3ad866e 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -10,6 +10,9 @@ add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_
 add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
 add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
+add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
+# lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
+# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
 # lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
 # lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
 # lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
@@ -37,6 +40,7 @@ if(NOT LITE_WITH_X86)
 endif()
 
 add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
+lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
 lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
 lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
diff --git a/lite/kernels/x86/conv_compute.h b/lite/kernels/x86/conv_compute.h
index 39114e1716a0f1830a739fc034a0845a36c35702..48cb3c74ef3c05675115ab7cec09f16322d1410a 100644
--- a/lite/kernels/x86/conv_compute.h
+++ b/lite/kernels/x86/conv_compute.h
@@ -16,15 +16,14 @@
 #include <Eigen/Core>
 #include <string>
 #include <vector>
+#include "lite/backends/x86/math/blas.h"
+#include "lite/backends/x86/math/im2col.h"
+#include "lite/backends/x86/math/vol2col.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_registry.h"
 #include "lite/core/types.h"
+#include "lite/fluid/eigen.h"
 #include "lite/operators/conv_op.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/depthwise_conv.h"
-#include "paddle/fluid/operators/math/im2col.h"
-#include "paddle/fluid/operators/math/vol2col.h"
 
 namespace paddle {
 namespace lite {
@@ -50,15 +49,14 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
  public:
   using param_t = operators::ConvParam;
 
   void Run() override {
+    auto& context = ctx_->As<X86Context>();
     auto& param = *param_.get_mutable<param_t>();
     lite::Tensor filter = *param.filter;
-    param.output->template mutable_data<T>();
-
+    param.output->mutable_data<T>();
     const int batch_size = static_cast<int>(param.x->dims()[0]);
     std::vector<int64_t> filter_shape_vec(filter.dims().Vectorize());
     std::vector<int64_t> output_shape_vec(param.output->dims().Vectorize());
-
     size_t data_dim = filter_shape_vec.size() - 2;
     std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
     col_shape_vec[0] = param.x->dims()[1] / param.groups;
@@ -70,7 +68,6 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1);
     bool is_expand = IsExpand(
         filter_shape_vec, param.strides, param.paddings, param.dilations);
-
     lite::Tensor col;
     lite::Tensor col_matrix;
     if (is_expand) {
@@ -80,40 +77,37 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
       col_matrix.Resize(col_matrix_shape);
     }
     lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size());
-
     lite::DDim filter_matrix_shape(std::vector<int64_t>{
        filter.dims()[0], filter.dims().production() / filter.dims()[0]});
     filter.Resize(filter_matrix_shape);
-
     lite::DDim output_matrix_shape(std::vector<int64_t>{
        param.output->dims()[1],
        param.output->dims().production() /
            (param.output->dims()[0] * param.output->dims()[1])});
-
     int in_step = static_cast<int>(param.x->dims()[1]) / param.groups;
     int out_step = static_cast<int>(param.output->dims()[1]) / param.groups;
-
-    paddle::operators::math::Vol2ColFunctor<platform::CPUDeviceContext, T>
-        vol2col;
-    paddle::operators::math::Im2ColFunctor<
-        paddle::operators::math::ColFormat::kCFO,
-        platform::CPUDeviceContext,
+    paddle::lite::x86::math::Vol2ColFunctor<lite::TargetType::kX86, T> vol2col;
+    paddle::lite::x86::math::Im2ColFunctor<
+        paddle::lite::x86::math::ColFormat::kCFO,
+        lite::TargetType::kX86,
         T>
         im2col;
-    auto blas = paddle::operators::math::GetBlas<platform::CPUDeviceContext, T>(
-        platform::CPUDeviceContext());
+    auto blas =
+        paddle::lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
     for (int i = 0; i < batch_size; i++) {
       lite::Tensor in_batch;
-      in_batch.ShareDataWith(
-          param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
+      lite::Tensor tmp_in_batch = param.x->Slice<T>(i, i + 1);
+      tmp_in_batch.Resize(input_shape);
+      in_batch.ShareDataWith(tmp_in_batch);
       lite::Tensor out_batch;
-      out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
-          output_matrix_shape.data()));
-
+      lite::Tensor tmp_out_batch = param.output->Slice<T>(i, i + 1);
+      tmp_out_batch.Resize(output_matrix_shape);
+      out_batch.ShareDataWith(tmp_out_batch);
       for (int g = 0; g < param.groups; g++) {
         lite::Tensor in_slice;
         in_slice.ShareDataWith(
-            in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step));
+            in_batch.Slice<T>(static_cast<int64_t>(g * in_step),
+                              static_cast<int64_t>((g + 1) * in_step)));
 
         if (!is_expand) {
           col.ShareDataWith(in_slice);
@@ -121,38 +115,40 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
           col_matrix.Resize(col_matrix_shape);
         } else if (data_dim == 2U) {
           // im2col
-          im2col(platform::CPUDeviceContext(),
-                 in_slice.raw_tensor(),
+          im2col(context,
+                 in_slice,
                  param.dilations,
                  param.strides,
                  std::vector<int>{param.paddings[0], param.paddings[1],
                                   param.paddings[0], param.paddings[1]},
-                 &(col.raw_tensor()));
+                 &(col));
         } else if (data_dim == 3U) {
           // vol2col
-          vol2col(platform::CPUDeviceContext(),
-                  in_slice.raw_tensor(),
+          vol2col(context,
+                  in_slice,
                   param.dilations,
                   param.strides,
                   param.paddings,
-                  &(col.raw_tensor()));
+                  &(col));
         }
 
         // gemm
         lite::Tensor out_slice;
         out_slice.ShareDataWith(
-            out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
+            out_batch.Slice<T>(static_cast<int64_t>(g * out_step),
+                               static_cast<int64_t>((g + 1) * out_step)));
         lite::Tensor filter_slice;
         filter_slice.ShareDataWith(
-            filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step));
-        blas.MatMul(filter_slice.raw_tensor(),
+            filter.Slice<T>(static_cast<int64_t>(g * out_step),
+                            static_cast<int64_t>((g + 1) * out_step)));
+        blas.MatMul(filter_slice,
                     false,
-                    col_matrix.raw_tensor(),
+                    col_matrix,
                     false,
                     T(1.0),
-                    &(out_slice.raw_tensor()),
+                    &(out_slice),
                     T(0.0));
       }
     }
diff --git a/lite/kernels/x86/conv_compute_test.cc b/lite/kernels/x86/conv_compute_test.cc
index 17efae41601925e217067ce07677bfc10da75bc9..d784018446e1478a3530430e1639a3fcf7e9c86a 100644
--- a/lite/kernels/x86/conv_compute_test.cc
+++ b/lite/kernels/x86/conv_compute_test.cc
@@ -14,6 +14,8 @@
 #include "lite/kernels/x86/conv_compute.h"
 #include <gtest/gtest.h>
+#include <memory>
+#include <utility>
 #include <vector>
 #include "lite/core/op_registry.h"
@@ -38,7 +40,7 @@ TEST(conv2d_x86, init) {
 
 TEST(conv2d_x86, run_test) {
   lite::Tensor x, filter, b, out;
-  constexpr int batch_size = 1;
+  const int batch_size = 1;
   std::vector<int64_t> x_shape{batch_size, 3, 3, 3};
   x.Resize(lite::DDim(x_shape));
   std::vector<int64_t> filter_shape{1, 3, 3, 3};
@@ -74,7 +76,10 @@ TEST(conv2d_x86, run_test) {
   param.paddings = {0, 0};
   param.groups = 1;
   param.dilations = {1, 1};
-
+  LOG(INFO) << 123;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+  conv2d.SetContext(std::move(ctx));
   conv2d.SetParam(param);
   conv2d.Run();