Unverified commit 0a9937d2 authored by XiangGao, committed by GitHub

improve group norm cpu precision and performance (#33176)

* improve group norm cpu precision and performance

* add unit test to group norm
Parent 387f2276
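For context, the kernel touched by this diff normalizes each (batch, group) block of the input: it accumulates the sum and the sum of squares, derives the mean and variance, and scales every element by 1/sqrt(var + epsilon). A minimal standalone sketch of that per-group computation follows; the function name and signature are illustrative only, not the operator's actual interface.

#include <algorithm>
#include <cmath>

// Normalize one (batch, group) block of n = channels_in_group * imsize
// contiguous values (NCHW layout assumed). Illustrative sketch only.
template <typename T>
void normalize_group(const T* x, T* y, int n, T epsilon) {
  T sum = T(0), sq_sum = T(0);
  for (int i = 0; i < n; ++i) {
    sum += x[i];
    sq_sum += x[i] * x[i];
  }
  T mean = sum / n;
  // E[x^2] - E[x]^2; clamp at 0 because rounding can push it slightly negative.
  T var = std::max(sq_sum / n - mean * mean, T(0));
  T inv_std = T(1) / std::sqrt(var + epsilon);
  for (int i = 0; i < n; ++i) y[i] = (x[i] - mean) * inv_std;
}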
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
@@ -73,6 +75,11 @@ class GroupNormKernel : public framework::OpKernel<T> {
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
const int64_t M = 8;
std::array<T, M> x_mean_arr;
std::array<T, M> x_var_arr;
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
T x_mean = 0, x_var = 0;
int number =
    std::min(group_size, static_cast<int>(C - gid * group_size));
@@ -83,7 +90,37 @@ class GroupNormKernel : public framework::OpKernel<T> {
if (data_layout == DataLayout::kNCHW) {
for (int cid = 0; cid < number; cid++) {
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M) {
// TODO(gaoxiang): AVX/AVX2/AVX512 cannot be used directly in a
// template class/function. Until a high-performance CPU vector
// extension is in place, unroll the loop manually to improve
// precision and performance.
x_mean_arr[0] += iter_x_data[0];
x_var_arr[0] += iter_x_data[0] * iter_x_data[0];
x_mean_arr[1] += iter_x_data[1];
x_var_arr[1] += iter_x_data[1] * iter_x_data[1];
x_mean_arr[2] += iter_x_data[2];
x_var_arr[2] += iter_x_data[2] * iter_x_data[2];
x_mean_arr[3] += iter_x_data[3];
x_var_arr[3] += iter_x_data[3] * iter_x_data[3];
x_mean_arr[4] += iter_x_data[4];
x_var_arr[4] += iter_x_data[4] * iter_x_data[4];
x_mean_arr[5] += iter_x_data[5];
x_var_arr[5] += iter_x_data[5] * iter_x_data[5];
x_mean_arr[6] += iter_x_data[6];
x_var_arr[6] += iter_x_data[6] * iter_x_data[6];
x_mean_arr[7] += iter_x_data[7];
x_var_arr[7] += iter_x_data[7] * iter_x_data[7];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var =
std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
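As the TODO notes, the kernel cannot call AVX intrinsics from this templated code yet, so the hunk above unrolls the reduction by hand. Eight independent accumulators shorten each lane's summation chain, which reduces accumulated rounding error on long reductions, and remove the loop-carried dependency that blocks vectorization; a scalar tail then handles the leftover elements. A standalone sketch of the same pattern, with an illustrative name and signature:

#include <array>
#include <cstdint>
#include <numeric>

// Compute sum and sum of squares of n contiguous values using 8 partial
// accumulators plus a scalar tail, mirroring the unrolled loop above.
template <typename T>
void sum_and_sqsum(const T* x, int64_t n, T* sum_out, T* sqsum_out) {
  constexpr int64_t M = 8;
  std::array<T, M> s{}, sq{};  // value-initialized to zero
  int64_t i = 0;
  for (; i + M <= n; i += M) {
    for (int64_t k = 0; k < M; ++k) {  // independent lanes, easy to vectorize
      s[k] += x[i + k];
      sq[k] += x[i + k] * x[i + k];
    }
  }
  T sum = std::accumulate(s.cbegin(), s.cend(), T(0));
  T sqsum = std::accumulate(sq.cbegin(), sq.cend(), T(0));
  for (; i < n; ++i) {  // scalar tail for the remainder
    sum += x[i];
    sqsum += x[i] * x[i];
  }
  *sum_out = sum;
  *sqsum_out = sqsum;
}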
@@ -91,7 +128,37 @@ class GroupNormKernel : public framework::OpKernel<T> {
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = tmp_x + cid;
int imid;
for (imid = 0; imid < imsize - (imsize % M);
imid += M, iter_x_data += M * C) {
// TODO(gaoxiang): AVX/AVX2/AVX512 cannot be used directly in a
// template class/function. Until a high-performance CPU vector
// extension is in place, unroll the loop manually to improve
// precision and performance.
x_mean_arr[0] += iter_x_data[0 * C];
x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C];
x_mean_arr[1] += iter_x_data[1 * C];
x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C];
x_mean_arr[2] += iter_x_data[2 * C];
x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C];
x_mean_arr[3] += iter_x_data[3 * C];
x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C];
x_mean_arr[4] += iter_x_data[4 * C];
x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C];
x_mean_arr[5] += iter_x_data[5 * C];
x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C];
x_mean_arr[6] += iter_x_data[6 * C];
x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C];
x_mean_arr[7] += iter_x_data[7 * C];
x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C];
}
x_mean =
std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
x_var =
std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
for (; imid < imsize; imid++, iter_x_data += C) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
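The NHWC branch applies the same eight-lane unrolling, but consecutive samples of one channel are C elements apart, so the loads read iter_x_data[k * C] and the pointer advances by M * C per iteration. A minimal sketch of that strided traversal, with illustrative names:

#include <cstdint>

// Sum channel cid of one NHWC image: samples of the same channel are
// spaced C apart, starting at offset cid.
template <typename T>
T channel_sum_nhwc(const T* image, int64_t imsize, int64_t C, int64_t cid) {
  T sum = T(0);
  const T* p = image + cid;
  for (int64_t i = 0; i < imsize; ++i, p += C) sum += p[0];
  return sum;
}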
@@ -101,8 +168,8 @@ class GroupNormKernel : public framework::OpKernel<T> {
x_mean /= number * imsize;
x_var /= number * imsize;
x_var = std::max(x_var - x_mean * x_mean, T(0));
T var_inv = T(1) / std::sqrt(x_var + epsilon);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;
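The two changed lines above finalize the moments. Computing the variance as E[x^2] - E[x]^2 can yield a tiny negative value in floating point when the true variance is near zero, and sqrt of a negative argument would turn var_inv into NaN; clamping with std::max(..., T(0)) guards against that, and T(1) / std::sqrt(...) keeps the arithmetic in the kernel's element type rather than forcing a double round-trip. A small sketch of that finalization under the same assumptions:

#include <algorithm>
#include <cmath>

// Turn accumulated sum / sum-of-squares into 1 / sqrt(var + eps), clamping
// the variance at zero to absorb floating-point rounding error.
template <typename T>
T inv_std_from_moments(T sum, T sq_sum, T count, T epsilon) {
  T mean = sum / count;
  T var = std::max(sq_sum / count - mean * mean, T(0));
  return T(1) / std::sqrt(var + epsilon);
}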
......
@@ -53,6 +53,15 @@ class TestDygraphGroupNormv2(unittest.TestCase):
weight_attr=False,
bias_attr=False)
def test_nn_exception():
    with fluid.dygraph.guard(p):
        def attr_data_format():
            out = paddle.nn.GroupNorm(
                num_groups=2, num_channels=2, data_format="NHWC")
        self.assertRaises(ValueError, attr_data_format)
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
@@ -61,6 +70,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
print("y1:", y1, "\ty2:", y2)
self.assertTrue(result)
test_weight_bias_false()
test_nn_exception()
def test_static(self):
places = [fluid.CPUPlace()]
......
@@ -375,7 +375,7 @@ class GroupNorm(layers.Layer):
self._num_channels = num_channels
self._num_groups = num_groups
if data_format != 'NCHW':
    raise ValueError("unsupported data layout:" + data_format)
param_shape = [self._num_channels]
......