diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 2f4abd0b644e9af10e9a88f126c339e9a3ad866e..7442a7be8b86b0440b17c5f70ff816a45864b9bd 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -34,6 +34,7 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) if(NOT LITE_WITH_X86) return() @@ -49,6 +50,7 @@ lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_ lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) +lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86) diff --git a/lite/kernels/x86/batch_norm_compute.h b/lite/kernels/x86/batch_norm_compute.h index 3a94b99b171e684db9923fc7180195f136f4c414..9190a407dfad2eb565360f1bbf847bf33bca45e1 100644 --- a/lite/kernels/x86/batch_norm_compute.h +++ b/lite/kernels/x86/batch_norm_compute.h @@ -13,12 +13,14 @@ // limitations under the License. #pragma once +#include #include #include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/operator.h" +#include "lite/core/types.h" +#include "lite/fluid/eigen.h" +#include "lite/operators/batch_norm_op.h" namespace paddle { namespace lite { @@ -42,6 +44,7 @@ class BatchNormCompute : public KernelLite { public: using param_t = operators::BatchNormParam; void Run() override { + // auto &context = ctx_->As(); auto ¶m = *param_.get_mutable(); bool global_stats = param.is_test || param.use_global_stats; @@ -55,12 +58,12 @@ class BatchNormCompute : public KernelLite { const int sample_size = x->dims().production() / N / C; // alloc memory - param.y->template mutable_data(); + param.y->mutable_data(); if (!param.is_test) { - param.mean_out->template mutable_data(); - param.variance_out->template mutable_data(); - param.saved_mean->template mutable_data(); - param.saved_variance->template mutable_data(); + param.mean_out->mutable_data(); + param.variance_out->mutable_data(); + param.saved_mean->mutable_data(); + param.saved_variance->mutable_data(); } if (!global_stats) { // saved_xx is use just in this batch of data @@ -79,8 +82,7 @@ class BatchNormCompute : public KernelLite { if ((N * sample_size) == 1) { LOG(WARNING) << "Only 1 element in normalization dimension, " << "we skip the batch norm calculation, let y = x."; - framework::TensorCopy( - x->raw_tensor(), platform::CPUPlace(), ¶m.y->raw_tensor()); + param.y->CopyDataFrom(*x); return; } diff --git a/lite/kernels/x86/batch_norm_compute_test.cc b/lite/kernels/x86/batch_norm_compute_test.cc index 254a6a7379e9ab18128020adfe18b206663b7877..e4c3268519345ad4eedbb079245f118979c90cbd 100644 --- a/lite/kernels/x86/batch_norm_compute_test.cc +++ b/lite/kernels/x86/batch_norm_compute_test.cc @@ -15,6 +15,8 @@ #include "lite/kernels/x86/batch_norm_compute.h" #include #include +#include +#include #include #include "lite/core/op_registry.h" @@ -116,6 +118,9 @@ TEST(batch_norm_x86, run_test) { param.saved_mean = &saved_mean; param.saved_variance = &saved_variance; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + batch_norm.SetContext(std::move(ctx)); batch_norm.SetParam(param); batch_norm.Run();