Commit d020229c authored by Hong Ming, committed by Tensor Tang

enable conv_winograd, fix conv_gemmlike bug, and update the unit tests of conv op

test=develop
Parent 5f0d7166
@@ -58,6 +58,111 @@ void scale<float>(const float* din, float* dout, int num, float scale,
}
}
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
int inner_dim, const float* scale_data,
const float* bias_data) {
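// Per-channel affine transform over an [outer_dim, scale_dim, inner_dim]
// buffer (NCHW: outer_dim = batch, scale_dim = channels, inner_dim = H * W):
// dout[n][c][k] = din[n][c][k] * scale_data[c] + bias_data[c].
// The inner dimension is processed 16 floats per NEON iteration.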
int cnt = inner_dim >> 4;
int remain = inner_dim % 16;
int size = inner_dim * scale_dim;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * size;
float* dout_ptr_n = dout + n * size;
#pragma omp parallel for
for (int i = 0; i < scale_dim; i++) {
const float* din_ptr = din_ptr_n + i * inner_dim;
float* dout_ptr = dout_ptr_n + i * inner_dim;
float scale = scale_data[i];
float32x4_t vscale = vdupq_n_f32(scale);
float bias = bias_data[i];
float32x4_t vbias = vdupq_n_f32(bias);
for (int j = 0; j < cnt; j++) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
din_ptr += 16;
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
dout_ptr += 16;
}
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
din_ptr++;
}
}
}
}
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data) {
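// Element-wise affine transform where scale and bias vary along the
// innermost dimension: dout[n][i] = din[n][i] * scale_data[i] + bias_data[i].
// This is the overload the (currently disabled) NHWC path would use.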
int cnt = scale_dim >> 4;
int remain = scale_dim % 16;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * scale_dim;
float* dout_ptr_n = dout + n * scale_dim;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
int idx = i << 4;
const float* din_ptr = din_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
float* dout_ptr = dout_ptr_n + idx;
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t vscale0 = vld1q_f32(scale_ptr);
float32x4_t vbias0 = vld1q_f32(bias_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
}
int idx = cnt << 4;
const float* din_ptr = din_ptr_n + idx;
float* dout_ptr = dout_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
dout_ptr++;
din_ptr++;
scale_ptr++;
bias_ptr++;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
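For reference, a plain scalar sketch of what the per-channel NEON overload above computes; the name scale_ref is illustrative and not part of this patch:
// Scalar reference for the [outer_dim, scale_dim, inner_dim] overload; handy
// for validating the NEON kernel in a unit test. Illustrative only.
void scale_ref(const float* din, float* dout, int outer_dim, int scale_dim,
               int inner_dim, const float* scale_data,
               const float* bias_data) {
  for (int n = 0; n < outer_dim; n++) {
    for (int c = 0; c < scale_dim; c++) {
      const float s = scale_data[c];
      const float b = bias_data[c];
      for (int k = 0; k < inner_dim; k++) {
        const int idx = (n * scale_dim + c) * inner_dim + k;
        dout[idx] = din[idx] * s + b;
      }
    }
  }
}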
@@ -22,6 +22,14 @@ namespace math {
template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
const float* scale_data, const float* bias_data);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data);
} // namespace math
} // namespace arm
} // namespace lite
@@ -10,6 +10,7 @@ cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -18,6 +19,7 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
@@ -30,6 +32,7 @@ set(arm_kernels
scale_compute_arm
softmax_compute_arm
conv_compute_arm
batch_norm_compute_arm
elementwise_add_compute_arm
pool_compute_arm
split_compute_arm
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void BatchNormCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
int64_t channel_size = 0;
switch (param.data_layout) {
case DATALAYOUT(kNCHW):
channel_size = x_dims[1];
break;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param.data_layout);
break;
}
new_scale.Resize({channel_size});
new_bias.Resize({channel_size});
auto* scale_data = param.scale->mutable_data<float>();
auto* bias_data = param.bias->mutable_data<float>();
auto* mean_data = param.mean->mutable_data<float>();
auto* variance_data = param.variance->mutable_data<float>();
auto* new_scale_data = new_scale.mutable_data<float>();
auto* new_bias_data = new_bias.mutable_data<float>();
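// Fold the normalization into one per-channel affine transform:
//   y = scale * (x - mean) / sqrt(variance + epsilon) + bias
//     = x * new_scale + new_bias,
// with new_scale = scale / sqrt(variance + epsilon)
// and  new_bias  = bias - new_scale * mean.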
for (int c = 0; c < channel_size; c++) {
float inv_scale = 1.f / (std::sqrt(variance_data[c] + param.epsilon));
new_bias_data[c] =
bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
new_scale_data[c] = inv_scale * scale_data[c];
}
}
}
void BatchNormCompute::Run() {
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto x_data = param.x->mutable_data<float>();
auto y_data = param.y->mutable_data<float>();
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
auto* new_scale_data = new_scale.mutable_data<float>();
auto* new_bias_data = new_bias.mutable_data<float>();
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
switch (param.data_layout) {
case DATALAYOUT(kNCHW):
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
inner_size, new_scale_data, new_bias_data);
break;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
// new_scale_data, new_bias_data);
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param.data_layout);
break;
}
} else {
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::BatchNormCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Mean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Variance", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class BatchNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::BatchNormParam;
void PrepareForRun() override;
void Run() override;
virtual ~BatchNormCompute() = default;
private:
Tensor new_scale;
Tensor new_bias;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void batch_norm_compute_ref(const operators::BatchNormParam& param) {
DDim x_dims = param.x->dims();
auto x_data = param.x->mutable_data<dtype>();
auto scale_data = param.scale->mutable_data<dtype>();
auto bias_data = param.bias->mutable_data<dtype>();
auto mean_data = param.mean->mutable_data<dtype>();
auto variance_data = param.variance->mutable_data<dtype>();
auto y_data = param.y->mutable_data<dtype>();
float epsilon = param.epsilon;
float momentum = param.momentum;
DataLayoutType data_layout = param.data_layout;
bool global_stats = param.is_test || param.use_global_stats;
if (global_stats) {
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
switch (data_layout) {
case DATALAYOUT(kNCHW):
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
break;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// inner_size = 1;
// break;
default:
LOG(FATAL) << "Unknown storage order: " << DataLayoutToStr(data_layout);
break;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
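// Walk the NCHW buffer linearly; every element is normalized with its
// channel's statistics, then scaled and shifted.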
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
dtype norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
} else {
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
TEST(batch_norm_arm, retrieve_op) {
auto batch_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"batch_norm");
ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front());
}
TEST(batch_norm_arm, init) {
BatchNormCompute batch_norm;
ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat));
ASSERT_EQ(batch_norm.target(), TARGET(kARM));
}
TEST(batch_norm_arm, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto c : {6, 32 /*, 128*/}) {
for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) {
for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) {
for (auto is_test : {/*false, */ true}) {
for (auto use_global_stats : {false, true}) {
for (auto epsilon : {1e-4f, 1e-5f}) {
for (auto momentum : {0.9f, 0.99f}) {
for (auto data_layout :
{DATALAYOUT(kNCHW) /*, DATALAYOUT(kNHWC)*/}) {
Tensor x;
Tensor scale;
Tensor bias;
Tensor mean;
Tensor variance;
Tensor y;
Tensor mean_out;
Tensor variance_out;
Tensor saved_mean;
Tensor saved_variance;
Tensor y_ref;
Tensor mean_out_ref;
Tensor variance_out_ref;
Tensor saved_mean_ref;
Tensor saved_variance_ref;
// set the dims of input, output, ref output tensors
std::vector<int64_t> in_out_shape;
switch (data_layout) {
case DATALAYOUT(kNCHW):
in_out_shape = {n, c, h, w};
break;
// case DATALAYOUT(kNHWC):
// in_out_shape = {n, h, w, c};
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(data_layout);
break;
}
x.Resize(in_out_shape);
scale.Resize({c});
bias.Resize({c});
mean.Resize({c});
variance.Resize({c});
y.Resize(in_out_shape);
mean_out.Resize({c});
variance_out.Resize({c});
saved_mean.Resize({c});
saved_variance.Resize({c});
y_ref.Resize(in_out_shape);
mean_out_ref.Resize({c});
variance_out_ref.Resize({c});
saved_mean_ref.Resize({c});
saved_variance_ref.Resize({c});
// initialize the data of input tensors
auto* x_data = x.mutable_data<float>();
auto* scale_data = scale.mutable_data<float>();
auto* bias_data = bias.mutable_data<float>();
auto* mean_data = mean.mutable_data<float>();
auto* variance_data = variance.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(i % 64);
}
for (int i = 0; i < scale.dims().production(); i++) {
scale_data[i] = static_cast<float>(i) * 0.01f + 0.03f;
}
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i) * 0.065f + 0.1f;
}
for (int i = 0; i < mean.dims().production(); i++) {
mean_data[i] = static_cast<float>(i) * 0.0565f;
}
for (int i = 0; i < variance.dims().production(); i++) {
variance_data[i] = static_cast<float>(i) * 2.08f + 1.5f;
}
// prepare kernel params and run
BatchNormCompute batch_norm;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
batch_norm.SetContext(std::move(ctx));
operators::BatchNormParam param;
param.x = &x;
param.scale = &scale;
param.bias = &bias;
param.mean = &mean;
param.variance = &variance;
param.is_test = is_test;
param.use_global_stats = use_global_stats;
param.epsilon = epsilon;
param.momentum = momentum;
param.data_layout = data_layout;
param.y = &y;
param.mean_out = &mean_out;
param.variance_out = &variance_out;
param.saved_mean = &saved_mean;
param.saved_variance = &saved_variance;
batch_norm.SetParam(param);
batch_norm.Launch();
// invoking ref implementation and compare results
param.y = &y_ref;
param.mean_out = &mean_out_ref;
param.variance_out = &variance_out_ref;
param.saved_mean = &saved_mean_ref;
param.saved_variance = &saved_variance_ref;
batch_norm_compute_ref<float>(param);
auto* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.dims().production(); i++) {
EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
@@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS})
cc_library(batch_norm_op_lite SRCS batch_norm_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
@@ -30,6 +31,7 @@ set(ops_lite
scale_op_lite
softmax_op_lite
reshape_op_lite
batch_norm_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
@@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BatchNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.bias);
CHECK_OR_FALSE(param_.scale);
CHECK_OR_FALSE(param_.mean);
CHECK_OR_FALSE(param_.variance);
CHECK_OR_FALSE(param_.y);
if (!param_.is_test) {
CHECK_OR_FALSE(param_.mean_out);
CHECK_OR_FALSE(param_.variance_out);
CHECK_OR_FALSE(param_.saved_mean);
CHECK_OR_FALSE(param_.saved_variance);
}
auto x_dims = param_.x->dims();
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
auto mean_dims = param_.mean->dims();
auto variance_dims = param_.variance->dims();
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimension.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimension.";
CHECK_EQ(mean_dims.size(), 1UL) << "Input Mean must have 1 dimension.";
CHECK_EQ(variance_dims.size(), 1UL)
<< "Input Variance must have 1 dimension.";
return true;
}
bool BatchNormOp::InferShape() const {
auto x_dims = param_.x->dims();
int64_t channel_size = 0;
switch (param_.data_layout) {
case DATALAYOUT(kNCHW):
channel_size = x_dims[1];
break;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG(FATAL) << "Unknown storage order: "
<< DataLayoutToStr(param_.data_layout);
break;
}
if (!param_.is_test) {
param_.mean_out->Resize({channel_size});
param_.variance_out->Resize({channel_size});
param_.saved_mean->Resize({channel_size});
param_.saved_variance->Resize({channel_size});
}
param_.y->Resize(x_dims);
return true;
}
bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
param_.scale =
scope->FindVar(op_desc.Input("Scale").front())->GetMutable<Tensor>();
param_.mean =
scope->FindVar(op_desc.Input("Mean").front())->GetMutable<Tensor>();
param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
param_.is_test = op_desc.GetAttr<bool>("is_test");
param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
if (!param_.is_test) {
param_.mean_out =
scope->FindVar(op_desc.Output("MeanOut").front())->GetMutable<Tensor>();
param_.variance_out = scope->FindVar(op_desc.Output("VarianceOut").front())
->GetMutable<Tensor>();
param_.saved_mean = scope->FindVar(op_desc.Output("SavedMean").front())
->GetMutable<Tensor>();
param_.saved_variance =
scope->FindVar(op_desc.Output("SavedVariance").front())
->GetMutable<Tensor>();
}
param_.epsilon = op_desc.GetAttr<float>("epsilon");
param_.momentum = op_desc.GetAttr<float>("momentum");
std::string data_layout = op_desc.GetAttr<std::string>("data_layout");
CHECK_EQ(data_layout, "NCHW") << "TODO(hong19860320): Only support NCHW.";
// param_.data_layout = StringToDataLayout(data_layout);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class BatchNormOp : public OpLite {
public:
BatchNormOp() {}
explicit BatchNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "batch_norm"; }
private:
mutable BatchNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(batch_norm_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* scale = scope.Var("scale")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* mean = scope.Var("mean")->GetMutable<Tensor>();
auto* variance = scope.Var("variance")->GetMutable<Tensor>();
auto* y = scope.Var("y")->GetMutable<Tensor>();
x->Resize({2, 32, 10, 20});
auto x_dims = x->dims();
const int64_t channel_size = x_dims[1]; // NCHW
scale->Resize({channel_size});
bias->Resize({channel_size});
mean->Resize({channel_size});
variance->Resize({channel_size});
// prepare op desc
cpp::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"x"});
desc.SetInput("Scale", {"scale"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", true);
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
desc.SetAttr("data_layout", std::string("NCHW"));
BatchNormOp batch_norm("batch_norm");
batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
batch_norm.Attach(desc, &scope);
batch_norm.CheckShape();
batch_norm.InferShape();
// check output dims
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), x_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
CHECK_EQ(y_dims[i], x_dims[i]);
}
}
TEST(batch_norm_op_lite, test_enable_is_test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* scale = scope.Var("scale")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* mean = scope.Var("mean")->GetMutable<Tensor>();
auto* variance = scope.Var("variance")->GetMutable<Tensor>();
auto* y = scope.Var("y")->GetMutable<Tensor>();
auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
x->Resize({2, 32, 10, 20});
auto x_dims = x->dims();
const int64_t channel_size = x_dims[1]; // NCHW
scale->Resize({channel_size});
bias->Resize({channel_size});
mean->Resize({channel_size});
variance->Resize({channel_size});
// prepare op desc
cpp::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"x"});
desc.SetInput("Scale", {"scale"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetOutput("MeanOut", {"mean_out"});
desc.SetOutput("VarianceOut", {"variance_out"});
desc.SetOutput("SavedMean", {"saved_mean"});
desc.SetOutput("SavedVariance", {"saved_variance"});
desc.SetAttr("is_test", false);
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
desc.SetAttr("data_layout", std::string("NCHW"));
BatchNormOp batch_norm("batch_norm");
batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
batch_norm.Attach(desc, &scope);
batch_norm.CheckShape();
batch_norm.InferShape();
// check output dims
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), x_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
CHECK_EQ(y_dims[i], x_dims[i]);
}
auto mean_out_dims = mean_out->dims();
auto variance_out_dims = variance_out->dims();
auto saved_mean_dims = saved_mean->dims();
auto saved_variance_dims = saved_variance->dims();
CHECK_EQ(mean_out_dims.size(), 1UL);
CHECK_EQ(variance_out_dims.size(), 1UL);
CHECK_EQ(saved_mean_dims.size(), 1UL);
CHECK_EQ(saved_variance_dims.size(), 1UL);
CHECK_EQ(mean_out_dims[0], channel_size);
CHECK_EQ(variance_out_dims[0], channel_size);
CHECK_EQ(saved_mean_dims[0], channel_size);
CHECK_EQ(saved_variance_dims[0], channel_size);
}
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -146,6 +146,25 @@ struct ConvParam {
std::string data_format{"Anylayout"};
};
// For BatchNorm op
struct BatchNormParam {
lite::Tensor* x{};
lite::Tensor* bias{};
lite::Tensor* scale{};
lite::Tensor* mean{};
lite::Tensor* variance{};
lite::Tensor* y{};
lite::Tensor* mean_out{};
lite::Tensor* variance_out{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
bool is_test{true};
bool use_global_stats{false};
float epsilon;
float momentum;
DataLayoutType data_layout{DATALAYOUT(kNCHW)};
};
// For Pooling op
struct PoolParam {
lite::Tensor* x{};