Unverified commit a3241ca7, authored by liu zhengxi, committed by GitHub

enable batch_norm op and add its unit tests, test=develop (#2201)

Parent 75e8a6fc
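For context: batch normalization transforms each channel as y = scale * (x - mean) / sqrt(variance + epsilon) + bias. When is_test or use_global_stats is set, the kernel takes the global-stats path and uses the running mean/variance; otherwise the statistics are computed from the current batch.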
@@ -34,6 +34,7 @@ add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
+ add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
if(NOT LITE_WITH_X86)
return()
@@ -49,6 +50,7 @@ lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
+ lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
lite_cc_test(test_relu_compute_x86 SRCS relu_compute_test.cc DEPS activation_compute_x86)
......
@@ -13,12 +13,14 @@
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <random>
#include <string>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "lite/core/types.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/batch_norm_op.h"
namespace paddle {
namespace lite {
@@ -42,6 +44,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::BatchNormParam;
void Run() override {
// auto &context = ctx_->As<X86Context>();
auto &param = *param_.get_mutable<operators::BatchNormParam>();
bool global_stats = param.is_test || param.use_global_stats;
@@ -55,12 +58,12 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
const int sample_size = x->dims().production() / N / C;
// alloc memory
- param.y->template mutable_data<T>();
+ param.y->mutable_data<T>();
if (!param.is_test) {
- param.mean_out->template mutable_data<T>();
- param.variance_out->template mutable_data<T>();
- param.saved_mean->template mutable_data<T>();
- param.saved_variance->template mutable_data<T>();
+ param.mean_out->mutable_data<T>();
+ param.variance_out->mutable_data<T>();
+ param.saved_mean->mutable_data<T>();
+ param.saved_variance->mutable_data<T>();
}
if (!global_stats) {
// saved_xx is used just in this batch of data
@@ -79,8 +82,7 @@ class BatchNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
- framework::TensorCopy(
-     x->raw_tensor(), platform::CPUPlace(), &param.y->raw_tensor());
+ param.y->CopyDataFrom(*x);
return;
}
......
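To make the global-stats branch above concrete, here is a minimal standalone sketch of the same per-channel transform (an illustration only, assuming a contiguous N x C x sample_size layout as implied by sample_size = production() / N / C; batch_norm_ref and its parameter names are hypothetical, not Lite APIs):

#include <cmath>
#include <vector>

// Reference inference-mode batch norm:
// y = scale * (x - mean) / sqrt(variance + epsilon) + bias, per channel.
void batch_norm_ref(const std::vector<float> &x, std::vector<float> *y,
                    const std::vector<float> &scale, const std::vector<float> &bias,
                    const std::vector<float> &mean, const std::vector<float> &variance,
                    int N, int C, int sample_size, float epsilon) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const float inv_std = 1.f / std::sqrt(variance[c] + epsilon);
      const float *xp = x.data() + (n * C + c) * sample_size;
      float *yp = y->data() + (n * C + c) * sample_size;
      for (int i = 0; i < sample_size; ++i) {
        yp[i] = scale[c] * (xp[i] - mean[c]) * inv_std + bias[c];
      }
    }
  }
}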
@@ -15,6 +15,8 @@
#include "lite/kernels/x86/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <iostream>
+ #include <memory>
+ #include <utility>
#include <vector>
#include "lite/core/op_registry.h"
@@ -116,6 +118,9 @@ TEST(batch_norm_x86, run_test) {
param.saved_mean = &saved_mean;
param.saved_variance = &saved_variance;
+ std::unique_ptr<KernelContext> ctx(new KernelContext);
+ ctx->As<X86Context>();
+ batch_norm.SetContext(std::move(ctx));
batch_norm.SetParam(param);
batch_norm.Run();
......
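After Run(), the test would typically check the kernel output against an independently computed reference; a hedged sketch of such a check (out and ref_out are illustrative names, not necessarily those used in the actual test file):

// out is the Tensor bound to param.y; ref_out holds the expected values,
// e.g. produced by a reference routine such as batch_norm_ref above.
const float *out_data = out.data<float>();
for (int64_t i = 0; i < out.dims().production(); ++i) {
  EXPECT_NEAR(out_data[i], ref_out[i], 1e-5f);
}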