提交 95372548 作者: liu zhengxi 提交者: GitHub

enable batch_norm op and add its unit tests, test=develop (#2201)

上级 ad541652
...@@ -34,6 +34,7 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_ ...@@ -34,6 +34,7 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
if(NOT LITE_WITH_X86) if(NOT LITE_WITH_X86)
return() return()
...@@ -49,6 +50,7 @@ lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_ ...@@ -49,6 +50,7 @@ lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86) lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86) lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <Eigen/Core>
#include <random> #include <random>
#include <string> #include <string>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h" #include "lite/core/types.h"
#include "paddle/fluid/framework/operator.h" #include "lite/fluid/eigen.h"
#include "lite/operators/batch_norm_op.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -42,6 +44,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -42,6 +44,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::BatchNormParam; using param_t = operators::BatchNormParam;
void Run() override { void Run() override {
// auto &context = ctx_->As<X86Context>();
auto &param = *param_.get_mutable<operators::BatchNormParam>(); auto &param = *param_.get_mutable<operators::BatchNormParam>();
bool global_stats = param.is_test || param.use_global_stats; bool global_stats = param.is_test || param.use_global_stats;
...@@ -55,12 +58,12 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -55,12 +58,12 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
const int sample_size = x->dims().production() / N / C; const int sample_size = x->dims().production() / N / C;
// alloc memory // alloc memory
param.y->template mutable_data<T>(); param.y->mutable_data<T>();
if (!param.is_test) { if (!param.is_test) {
param.mean_out->template mutable_data<T>(); param.mean_out->mutable_data<T>();
param.variance_out->template mutable_data<T>(); param.variance_out->mutable_data<T>();
param.saved_mean->template mutable_data<T>(); param.saved_mean->mutable_data<T>();
param.saved_variance->template mutable_data<T>(); param.saved_variance->mutable_data<T>();
} }
if (!global_stats) { if (!global_stats) {
// saved_xx is used just in this batch of data // saved_xx is used just in this batch of data
...@@ -79,8 +82,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -79,8 +82,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
if ((N * sample_size) == 1) { if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, " LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x."; << "we skip the batch norm calculation, let y = x.";
framework::TensorCopy( param.y->CopyDataFrom(*x);
x->raw_tensor(), platform::CPUPlace(), &param.y->raw_tensor());
return; return;
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "lite/kernels/x86/batch_norm_compute.h" #include "lite/kernels/x86/batch_norm_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -116,6 +118,9 @@ TEST(batch_norm_x86, run_test) { ...@@ -116,6 +118,9 @@ TEST(batch_norm_x86, run_test) {
param.saved_mean = &saved_mean; param.saved_mean = &saved_mean;
param.saved_variance = &saved_variance; param.saved_variance = &saved_variance;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
batch_norm.SetContext(std::move(ctx));
batch_norm.SetParam(param); batch_norm.SetParam(param);
batch_norm.Run(); batch_norm.Run();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册