提交 6772190f 编写于 作者: L liuqi

Optimize a computation of batch_norm kernel and ops' BUILD file.

上级 303311af
...@@ -47,7 +47,7 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> { ...@@ -47,7 +47,7 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
new_offset = offset[c] - mean[c] * new_scale; new_offset = offset[c] - mean[c] * new_scale;
for (TIndex i = 0; i < n; ++i) { for (TIndex i = 0; i < n; ++i) {
TIndex pos = i * channel * sample_size + c * sample_size; TIndex pos = (i * channel + c) * sample_size;
const T* input_sample_ptr = input + pos; const T* input_sample_ptr = input + pos;
T* output_sample_ptr = output + pos; T* output_sample_ptr = output + pos;
for (TIndex j = 0; j < sample_size; ++j) { for (TIndex j = 0; j < sample_size; ++j) {
......
...@@ -41,7 +41,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy ...@@ -41,7 +41,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset); float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (TIndex i = 0; i < n; ++i) { for (TIndex i = 0; i < n; ++i) {
TIndex pos = i * channel * sample_size + c * sample_size; TIndex pos = (i * channel + c) * sample_size;
const float* input_sample_ptr = input + pos; const float* input_sample_ptr = input + pos;
float* output_sample_ptr = output + pos; float* output_sample_ptr = output + pos;
......
...@@ -12,7 +12,7 @@ load("//mace:mace.bzl", "if_android") ...@@ -12,7 +12,7 @@ load("//mace:mace.bzl", "if_android")
cc_library( cc_library(
name = "test", name = "test",
testonly = 1, testonly = 1,
srcs = [ hdrs = [
"ops_test_util.h", "ops_test_util.h",
], ],
deps = [ deps = [
...@@ -23,14 +23,14 @@ cc_library( ...@@ -23,14 +23,14 @@ cc_library(
cc_library( cc_library(
name = "ops", name = "ops",
srcs = [ srcs = glob(
"batch_norm.cc", ["*.cc"],
"relu.cc", exclude = ["*_test.cc"],
], ),
hdrs = glob([ hdrs = glob(
"relu.h", ["*.h"],
"batch_norm.h", exclude = ["ops_test_util.h"],
]), ),
copts = ["-std=c++11"], copts = ["-std=c++11"],
deps = [ deps = [
"//mace/core", "//mace/core",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册