提交 6772190f 编写于 作者: L liuqi

Optimize a computation of batch_norm kernel and ops' BUILD file.

上级 303311af
......@@ -47,7 +47,7 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
new_offset = offset[c] - mean[c] * new_scale;
for (TIndex i = 0; i < n; ++i) {
TIndex pos = i * channel * sample_size + c * sample_size;
TIndex pos = (i * channel + c) * sample_size;
const T* input_sample_ptr = input + pos;
T* output_sample_ptr = output + pos;
for (TIndex j = 0; j < sample_size; ++j) {
......
......@@ -41,7 +41,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (TIndex i = 0; i < n; ++i) {
TIndex pos = i * channel * sample_size + c * sample_size;
TIndex pos = (i * channel + c) * sample_size;
const float* input_sample_ptr = input + pos;
float* output_sample_ptr = output + pos;
......
......@@ -12,7 +12,7 @@ load("//mace:mace.bzl", "if_android")
cc_library(
name = "test",
testonly = 1,
srcs = [
hdrs = [
"ops_test_util.h",
],
deps = [
......@@ -23,14 +23,14 @@ cc_library(
cc_library(
name = "ops",
srcs = [
"batch_norm.cc",
"relu.cc",
],
hdrs = glob([
"relu.h",
"batch_norm.h",
]),
srcs = glob(
["*.cc"],
exclude = ["*_test.cc"],
),
hdrs = glob(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = ["-std=c++11"],
deps = [
"//mace/core",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册