diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 671233b43b84ebd43611bc850a9780e86c8df7f5..0bdaee1b0618434355fd0f3d17d98fc58808fa51 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -170,7 +170,7 @@ paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', ' paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d8a1c8b686b55631ba1b77805e4eacf')) paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23')) paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '79797f827d89ae72c77960e9696883a9')) -paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '96b24820e8863d6044d5be4eaaddb9fd')) +paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '65231cc8281815124934b1439fbb750c')) paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '9461e67095a6fc5d568fb2ce8fef66ff')) paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '54e1675aa0364f4a78fa72804ec0f413')) paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c')) diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 92772f2bc39321e28d091beeff986fb09d259432..e184ff14a5534dc40e87af0be45ca3409f1bdb18 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -38,9 +38,11 @@ class GroupNormOp : public framework::OperatorWithKernel { "Output(Mean) of GroupNormOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Variance"), "Output(Variance) of GroupNormOp should not be null."); - auto x_dim = ctx->GetInputDim("X"); - auto channel_num = x_dim[1]; + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + const int64_t channel_num = + (data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]); auto batch_size = x_dim[0]; auto groups = ctx->Attrs().Get("groups"); PADDLE_ENFORCE_LE( @@ -91,7 +93,9 @@ class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker { .AddCustomChecker([](const int &groups) { PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero."); }); - + AddAttr("data_layout", + "An optional string from: \"NHWC\", \"NCHW\". ") + .SetDefault("NCHW"); AddComment(R"DOC( Group Normalization diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index 3bf8586254e9867c7f5151178db866655df11535..b7f79be45be84f2557c34300922506a5840c5dd5 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -19,6 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using DataLayout = framework::DataLayout; enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; #define CHECK_CASE(i, flags, kernel_name, ...) \ @@ -45,18 +46,27 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { } template -__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, +__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W, int imsize, int groups, - int group_size, T* mean, T* var) { + int group_size, T* mean, T* var, + const DataLayout data_layout) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; + int H = imsize / W; int number = min(group_size, static_cast(C - gid * group_size)); int ccid = gid * group_size + cid; if (ccid >= C) return; T x_mean = 0, x_var = 0; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { - T val = x[(bid * C + ccid) * imsize + imid]; + T val; + if (data_layout == DataLayout::kNCHW) { + val = x[(bid * C + ccid) * imsize + imid]; + } else { + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid]; + } x_mean += val; x_var += val * val; } @@ -69,11 +79,13 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, template __global__ void GroupNormForward(const T* x, const T* mean, const T* var, const T* scale, const T* bias, int N, int C, - int imsize, int groups, int group_size, - T epsilon, T* y, T* real_var) { + int W, int imsize, int groups, int group_size, + T epsilon, T* y, T* real_var, + const DataLayout data_layout) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; + int H = imsize / W; int ccid = gid * group_size + cid; if (ccid >= C) return; T x_mean = mean[bid * groups + gid]; @@ -82,11 +94,23 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var, T var_inv = 1.0 / sqrt(x_var + epsilon); if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { - T val = x[(bid * C + ccid) * imsize + imid]; + T val; + int hid, wid; + if (data_layout == DataLayout::kNCHW) { + val = x[(bid * C + ccid) * imsize + imid]; + } else { + hid = imid / W; + wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid]; + } val = (val - x_mean) * var_inv; if (flags & kHasScale) val *= scale[gid * group_size + cid]; if (flags & kHasBias) val += bias[gid * group_size + cid]; - y[(bid * C + ccid) * imsize + imid] = val; + if (data_layout == DataLayout::kNCHW) { + y[(bid * C + ccid) * imsize + imid] = val; + } else { + y[(bid * H + hid) * W * C + wid * C + ccid] = val; + } } } @@ -95,6 +119,9 @@ class GroupNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const float epsilon = ctx.Attr("epsilon"); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); @@ -106,7 +133,13 @@ class GroupNormKernel const auto groups = ctx.Attr("groups"); const auto x_dims = x->dims(); - const int group_size = (x_dims[1] - 1) / groups + 1; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = (C - 1) / groups + 1; + const int W = + (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] + : x_dims[x_dims.size() - 2]); y->mutable_data(ctx.GetPlace()); mean->mutable_data(ctx.GetPlace()); @@ -130,31 +163,32 @@ class GroupNormKernel const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = x_dims[2] * x_dims[3]; + int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] + : x_dims[1] * x_dims[2]); + int block_size = std::min(1024, imsize); dim3 grid(group_size, groups, x_dims[0]); dim3 threads(block_size, 1, 1); GroupNormForwardGetMeanAndVar<<>>( - x_data, x_dims[0], x_dims[1], imsize, groups, group_size, mean_data, - temp_var_data); + x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data, + temp_var_data, data_layout); int flags = (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data, - scale_data, bias_data, x_dims[0], x_dims[1], imsize, - groups, group_size, epsilon, y_data, var_data); + scale_data, bias_data, x_dims[0], C, W, imsize, groups, + group_size, epsilon, y_data, var_data, data_layout); } }; template -__global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale, - const T* bias, const T* d_y, - int N, int C, int imsize, - int groups, int group_size, - T epsilon, T* d_mean, T* d_var, - T* d_scale, T* d_bias) { +__global__ void GroupNormBackwardGetMeanAndVar( + const T* x, const T* scale, const T* bias, const T* d_y, int N, int C, + int W, int imsize, int groups, int group_size, T epsilon, T* d_mean, + T* d_var, T* d_scale, T* d_bias, const DataLayout data_layout) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; + int H = imsize / W; int number = min(group_size, static_cast(C - gid * group_size)); int ccid = gid * group_size + cid; if (ccid >= C) return; @@ -165,8 +199,16 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale, T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { - T val = x[(bid * C + ccid) * imsize + imid] - x_bias; - T dval = d_y[(bid * C + ccid) * imsize + imid]; + T val, dval; + if (data_layout == DataLayout::kNCHW) { + val = x[(bid * C + ccid) * imsize + imid] - x_bias; + dval = d_y[(bid * C + ccid) * imsize + imid]; + } else { + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias; + dval = d_y[(bid * H + hid) * W * C + wid * C + ccid]; + } d_var_data += val * dval; d_mean_data += dval * x_scale; @@ -184,12 +226,14 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale, template __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, const T* bias, const T* var, const T* d_mean, - const T* d_var, int N, int C, int imsize, - int groups, int group_size, T epsilon, - T* d_x) { + const T* d_var, int N, int C, int W, + int imsize, int groups, int group_size, + T epsilon, T* d_x, + const DataLayout data_layout) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; + int H = imsize / W; int number = min(group_size, static_cast(C - gid * group_size)); int ccid = gid * group_size + cid; if (ccid >= C) return; @@ -206,12 +250,23 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, if (x_scale != 0) x_scale_inv = 1.0 / x_scale; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { - T tmp = x[(bid * C + ccid) * imsize + imid]; - T v_y = (tmp - x_bias) * x_scale_inv; - T dly = d_y[(bid * C + ccid) * imsize + imid]; - d_x[(bid * C + ccid) * imsize + imid] = - x_var_inv * - (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + if (data_layout == DataLayout::kNCHW) { + T tmp = x[(bid * C + ccid) * imsize + imid]; + T v_y = (tmp - x_bias) * x_scale_inv; + T dly = d_y[(bid * C + ccid) * imsize + imid]; + d_x[(bid * C + ccid) * imsize + imid] = + x_var_inv * + (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + } else { + int hid = imid / W; + int wid = imid % W; + T tmp = x[(bid * H + hid) * W * C + wid * C + ccid]; + T v_y = (tmp - x_bias) * x_scale_inv; + T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid]; + d_x[(bid * H + hid) * W * C + wid * C + ccid] = + x_var_inv * + (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + } } } @@ -220,6 +275,9 @@ class GroupNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const float epsilon = ctx.Attr("epsilon"); auto* x = ctx.Input("Y"); auto* var = ctx.Input("Variance"); @@ -234,7 +292,13 @@ class GroupNormGradKernel auto* d_bias = ctx.Output(framework::GradVarName("Bias")); const auto& x_dims = x->dims(); - const int group_size = (x_dims[1] - 1) / groups + 1; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = (C - 1) / groups + 1; + const int W = + (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] + : x_dims[x_dims.size() - 2]); d_x->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; @@ -273,21 +337,23 @@ class GroupNormGradKernel const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = x_dims[2] * x_dims[3]; + int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] + : x_dims[1] * x_dims[2]); + int block_size = std::min(1024, imsize); dim3 grid(group_size, groups, x_dims[0]); dim3 threads(block_size, 1, 1); int flags = (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, x_data, scale_data, - bias_data, y_data, x_dims[0], x_dims[1], imsize, groups, + bias_data, y_data, x_dims[0], C, W, imsize, groups, group_size, epsilon, temp_mean_data, temp_var_data, - d_scale_data, d_bias_data); + d_scale_data, d_bias_data, data_layout); if (d_x_data != nullptr) { UNROLL_ALL_CASES(flags, GroupNormBackward, x_data, y_data, scale_data, bias_data, var_data, temp_mean_data, temp_var_data, - x_dims[0], x_dims[1], imsize, groups, group_size, - epsilon, d_x_data); + x_dims[0], C, W, imsize, groups, group_size, epsilon, + d_x_data, data_layout); } } }; diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h index 498e65f614925f746dacacd3453e046a08ff5494..d4a1b3f036bba7eb193e5854cff2c239be18425c 100644 --- a/paddle/fluid/operators/group_norm_op.h +++ b/paddle/fluid/operators/group_norm_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include +#include +#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" @@ -31,6 +33,9 @@ template class GroupNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const float epsilon = ctx.Attr("epsilon"); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); @@ -42,7 +47,10 @@ class GroupNormKernel : public framework::OpKernel { const auto groups = ctx.Attr("groups"); const auto x_dims = x->dims(); - const int group_size = (x_dims[1] - 1) / groups + 1; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = (C - 1) / groups + 1; y->mutable_data(ctx.GetPlace()); mean->mutable_data(ctx.GetPlace()); @@ -58,36 +66,75 @@ class GroupNormKernel : public framework::OpKernel { const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = x_dims[2] * x_dims[3]; + int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] + : x_dims[1] * x_dims[2]); + auto* iter_x_data = x_data; auto* iter_y_data = y_data; - for (int bid = 0; bid < x_dims[0]; bid++) + for (int bid = 0; bid < x_dims[0]; bid++) { for (int gid = 0; gid < groups; gid++) { T x_mean = 0, x_var = 0; - int number = std::min(group_size, - static_cast(x_dims[1] - gid * group_size)); - auto* tmp = iter_x_data; - for (int cid = 0; cid < number; cid++) { - for (int imid = 0; imid < imsize; imid++, iter_x_data++) { - x_mean += iter_x_data[0]; - x_var += iter_x_data[0] * iter_x_data[0]; + int number = + std::min(group_size, static_cast(C - gid * group_size)); + auto* tmp_x = iter_x_data; + auto* x_src_data = iter_x_data; + auto* tmp_y = iter_y_data; + auto* y_src_data = iter_y_data; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; imid++, iter_x_data++) { + x_mean += iter_x_data[0]; + x_var += iter_x_data[0] * iter_x_data[0]; + } + } + } else { + for (int cid = 0; cid < number; cid++) { + iter_x_data = tmp_x + cid; + for (int imid = 0; imid < imsize; imid++, iter_x_data += C) { + x_mean += iter_x_data[0]; + x_var += iter_x_data[0] * iter_x_data[0]; + } } + iter_x_data = tmp_x + group_size; } + x_mean /= number * imsize; x_var /= number * imsize; x_var = x_var - x_mean * x_mean; T var_inv = 1.0 / sqrt(x_var + epsilon); mean_data[bid * groups + gid] = x_mean; var_data[bid * groups + gid] = x_var; - for (int cid = 0; cid < number; cid++) { - for (int imid = 0; imid < imsize; imid++, tmp++, iter_y_data++) { - T val = (tmp[0] - x_mean) * var_inv; - if (scale_data) val *= scale_data[gid * group_size + cid]; - if (bias_data) val += bias_data[gid * group_size + cid]; - iter_y_data[0] = val; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) { + T val = (tmp_x[0] - x_mean) * var_inv; + if (scale_data) val *= scale_data[gid * group_size + cid]; + if (bias_data) val += bias_data[gid * group_size + cid]; + iter_y_data[0] = val; + } } + } else { + for (int cid = 0; cid < number; cid++) { + tmp_x = x_src_data + cid; + iter_y_data = y_src_data + cid; + for (int imid = 0; imid < imsize; + imid++, tmp_x += C, iter_y_data += C) { + T val = (tmp_x[0] - x_mean) * var_inv; + if (scale_data) val *= scale_data[gid * group_size + cid]; + if (bias_data) val += bias_data[gid * group_size + cid]; + iter_y_data[0] = val; + } + } + iter_y_data = tmp_y + group_size; } } + if (data_layout == DataLayout::kNHWC) { + iter_x_data = x_data + (bid + 1) * C * imsize; + iter_y_data = y_data + (bid + 1) * C * imsize; + } + } } }; @@ -95,6 +142,9 @@ template class GroupNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); const float epsilon = ctx.Attr("epsilon"); auto* x = ctx.Input("Y"); auto* var = ctx.Input("Variance"); @@ -109,7 +159,10 @@ class GroupNormGradKernel : public framework::OpKernel { auto* d_bias = ctx.Output(framework::GradVarName("Bias")); const auto& x_dims = x->dims(); - const int group_size = (x_dims[1] - 1) / groups + 1; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int group_size = (C - 1) / groups + 1; d_x->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; @@ -137,54 +190,112 @@ class GroupNormGradKernel : public framework::OpKernel { const T* bias_data = nullptr; if (bias) bias_data = bias->data(); - int imsize = x_dims[2] * x_dims[3]; + int imsize = (data_layout == DataLayout::kNCHW ? x_dims[2] * x_dims[3] + : x_dims[1] * x_dims[2]); auto* iter_x_data = x_data; auto* iter_d_x_data = d_x_data; auto* iter_y_data = y_data; - for (int bid = 0; bid < x_dims[0]; bid++) + for (int bid = 0; bid < x_dims[0]; bid++) { for (int gid = 0; gid < groups; gid++) { T x_var = var_data[bid * groups + gid]; T var_inv = 1.0 / sqrt(x_var + epsilon); - int number = std::min(group_size, - static_cast(x_dims[1] - gid * group_size)); + int number = + std::min(group_size, static_cast(C - gid * group_size)); T number_inv = 1.0 / (number * imsize); - auto* iter_x_data2 = iter_x_data; - auto* iter_y_data2 = iter_y_data; + auto* tmp_x = iter_x_data; + auto* tmp_y = iter_y_data; + auto* tmp_d_x = iter_d_x_data; + auto* x_src_data = iter_x_data; + auto* y_src_data = iter_y_data; + auto* iter_x_data_backup = iter_x_data; + auto* iter_y_data_backup = iter_y_data; + auto* iter_d_x_data_backup = iter_d_x_data; T dp_scale = 0, dp_bias = 0; - for (int cid = 0; cid < number; cid++) { - for (int imid = 0; imid < imsize; - imid++, iter_x_data++, iter_y_data++) { - T val = iter_x_data[0]; - if (bias_data) val -= bias_data[gid * group_size + cid]; - T dval = iter_y_data[0]; - dp_scale += val * dval; - dp_bias += dval * scale_data[gid * group_size + cid]; - - if (scale_data && scale_data[gid * group_size + cid] != 0) - val /= scale_data[gid * group_size + cid]; - if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; - if (d_scale_data) - d_scale_data[gid * group_size + cid] += val * dval; + + if (data_layout == DataLayout::kNCHW) { + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, iter_x_data++, iter_y_data++) { + T val = iter_x_data[0]; + if (bias_data) val -= bias_data[gid * group_size + cid]; + T dval = iter_y_data[0]; + dp_scale += val * dval; + dp_bias += dval * scale_data[gid * group_size + cid]; + + if (scale_data && scale_data[gid * group_size + cid] != 0) + val /= scale_data[gid * group_size + cid]; + if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; + if (d_scale_data) + d_scale_data[gid * group_size + cid] += val * dval; + } } - } - for (int cid = 0; cid < number; cid++) { - for (int imid = 0; imid < imsize; - imid++, iter_d_x_data++, iter_x_data2++, iter_y_data2++) { - T v_y = iter_x_data2[0]; - T dly = iter_y_data2[0]; - T dss = dp_scale; - T dbs = dp_bias; - T v_scale = scale_data[gid * group_size + cid]; - T v_bias = bias_data[gid * group_size + cid]; - v_y -= v_bias; - if (v_scale != 0) v_y /= v_scale; - iter_d_x_data[0] = - (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) * - var_inv; + for (int cid = 0; cid < number; cid++) { + for (int imid = 0; imid < imsize; + imid++, iter_d_x_data++, tmp_x++, tmp_y++) { + T v_y = tmp_x[0]; + T dly = tmp_y[0]; + T dss = dp_scale; + T dbs = dp_bias; + T v_scale = scale_data[gid * group_size + cid]; + T v_bias = bias_data[gid * group_size + cid]; + v_y -= v_bias; + if (v_scale != 0) v_y /= v_scale; + iter_d_x_data[0] = + (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) * + var_inv; + } + } + } else { + for (int cid = 0; cid < number; cid++) { + iter_x_data = x_src_data + cid; + iter_y_data = y_src_data + cid; + for (int imid = 0; imid < imsize; + imid++, iter_x_data += C, iter_y_data += C) { + T val = iter_x_data[0]; + if (bias_data) val -= bias_data[gid * group_size + cid]; + T dval = iter_y_data[0]; + dp_scale += val * dval; + dp_bias += dval * scale_data[gid * group_size + cid]; + + if (scale_data && scale_data[gid * group_size + cid] != 0) + val /= scale_data[gid * group_size + cid]; + if (d_bias_data) d_bias_data[gid * group_size + cid] += dval; + if (d_scale_data) + d_scale_data[gid * group_size + cid] += val * dval; + } } + + for (int cid = 0; cid < number; cid++) { + tmp_x = x_src_data + cid; + tmp_y = y_src_data + cid; + iter_d_x_data = tmp_d_x + cid; + for (int imid = 0; imid < imsize; + imid++, iter_d_x_data += C, tmp_x += C, tmp_y += C) { + T v_y = tmp_x[0]; + T dly = tmp_y[0]; + T dss = dp_scale; + T dbs = dp_bias; + T v_scale = scale_data[gid * group_size + cid]; + T v_bias = bias_data[gid * group_size + cid]; + v_y -= v_bias; + if (v_scale != 0) v_y /= v_scale; + iter_d_x_data[0] = + (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) * + var_inv; + } + } + iter_x_data = iter_x_data_backup + group_size; + iter_y_data = iter_y_data_backup + group_size; + iter_d_x_data = iter_d_x_data_backup + group_size; } } + if (data_layout == DataLayout::kNHWC) { + iter_x_data = x_data + (bid + 1) * C * imsize; + iter_d_x_data = d_x_data + (bid + 1) * C * imsize; + iter_y_data = y_data + (bid + 1) * C * imsize; + } + } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index aa3dea655bfe17ce4940ea8ceb142ed851b581b1..28d7cfef71d7a0de3c7d419f9dbd6457b622ec72 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3764,7 +3764,7 @@ def group_norm(input, bias :math:`b`. If it is set to False, no bias will be added to the output units. If it is set to None, the bias is initialized zero. Default: None. act(str): Activation to be applied to the output of group normalizaiton. - data_layout(string|NCHW): Only NCHW is supported. + data_layout(string, default NCHW): NCHW(num_batch, channels, h, w) or NHWC(num_batch, h, w, channels). name (str): The name of this layer. It is optional. Returns: @@ -3783,9 +3783,12 @@ def group_norm(input, # create intput and parameters inputs = {'X': input} input_shape = input.shape - if data_layout != 'NCHW': - raise ValueError("unsupported data layout:" + data_layout) - param_shape = [input_shape[1]] + if data_layout != 'NCHW' and data_layout != 'NHWC': + raise ValueError( + "Param(data_layout) of Op(fluid.layers.group_norm) got wrong value: received " + + data_layout + " but only NCHW or NHWC supported.") + channel_num = input_shape[1] if data_layout == 'NCHW' else input_shape[-1] + param_shape = [channel_num] if param_attr: scale = helper.create_parameter( attr=helper.param_attr, @@ -3811,8 +3814,11 @@ def group_norm(input, "Mean": mean_out, "Variance": variance_out, }, - attrs={"epsilon": epsilon, - "groups": groups}) + attrs={ + "epsilon": epsilon, + "groups": groups, + "data_layout": data_layout + }) return helper.append_activation(group_norm_out) diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op.py b/python/paddle/fluid/tests/unittests/test_group_norm_op.py index 386c3b1f0e438dc50943009f0fe8663838a32ecc..7fcde530fe99a2ee910f5e58780ea2682f18c797 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op.py @@ -24,7 +24,9 @@ from op_test import OpTest from testsuite import create_op -def group_norm_naive(x, scale, bias, epsilon, groups): +def group_norm_naive(x, scale, bias, epsilon, groups, data_layout): + if data_layout == "NHWC": + x = np.transpose(x, (0, 3, 1, 2)) # NHWC => NCHW N, C, H, W = x.shape G = groups x = x.reshape((N * G, -1)) @@ -33,6 +35,8 @@ def group_norm_naive(x, scale, bias, epsilon, groups): output = (x - mean) / np.sqrt(var + epsilon) output = output.reshape((N, C, H, W)) * scale.reshape( (-1, 1, 1)) + bias.reshape((-1, 1, 1)) + if data_layout == "NHWC": + output = np.transpose(output, (0, 2, 3, 1)) # NCHW => NHWC return output, mean.reshape((N, G)), var.reshape((N, G)) @@ -42,15 +46,18 @@ class TestGroupNormOp(OpTest): self.data_format = "NCHW" self.dtype = np.float32 self.shape = (2, 4, 3, 3) - self.attrs = {'epsilon': 1e-5, 'groups': 2} + self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"} self.compare_between_place = False self.init_test_case() input = np.random.random(self.shape).astype(self.dtype) + if self.data_format == "NHWC": + input = np.transpose(input, (0, 2, 3, 1)) scale = np.random.random([self.shape[1]]).astype(self.dtype) bias = np.random.random([self.shape[1]]).astype(self.dtype) output, mean, var = group_norm_naive( - input, scale, bias, self.attrs['epsilon'], self.attrs['groups']) + input, scale, bias, self.attrs['epsilon'], self.attrs['groups'], + self.data_format) self.inputs = { 'X': OpTest.np_dtype_to_fluid_dtype(input), @@ -58,6 +65,7 @@ class TestGroupNormOp(OpTest): 'Bias': OpTest.np_dtype_to_fluid_dtype(bias) } self.outputs = {'Y': output, 'Mean': mean, 'Variance': var} + self.attrs['data_layout'] = self.data_format def test_check_output(self): atol = 1e-4 @@ -66,6 +74,7 @@ class TestGroupNormOp(OpTest): # add inplace_atol bacause group_norm doesn't ensure computational consistency self.check_output_with_place( place, atol=atol, inplace_atol=inplace_atol) + if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_output_with_place( @@ -94,6 +103,7 @@ class TestGroupNormOp(OpTest): if self.compare_between_place: self.do_compare_between_place() return + place = core.CPUPlace() self.check_grad_with_place( place, set(['X', 'Scale', 'Bias']), 'Y', max_relative_error=0.01) @@ -143,5 +153,85 @@ class TestGroupNormOpLargeData(TestGroupNormOp): self.compare_between_place = True +class TestGroupNormOp1_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 1 + self.data_format = "NHWC" + + +class TestGroupNormOp2_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 4 + self.data_format = "NHWC" + + +class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 1 + self.attrs['epsilon'] = 0.5 + self.data_format = "NHWC" + + +class TestGroupNormOpBigEps2_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.attrs['groups'] = 4 + self.attrs['epsilon'] = 0.5 + self.data_format = "NHWC" + + +class TestGroupNormOpBigEps3_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.attrs['epsilon'] = 0.5 + self.data_format = "NHWC" + + +class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp): + def init_test_case(self): + self.shape = (2, 64, 32, 32) # NCHW + self.attrs['groups'] = 8 + self.data_format = "NHWC" + self.compare_between_place = True + + +class TestGroupNormAPI_With_NHWC(OpTest): + def test_case1(self): + data1 = fluid.layers.data( + name='data1', shape=[3, 3, 4], dtype='float32') + out1 = fluid.layers.group_norm( + input=data1, groups=2, data_layout="NHWC") + data2 = fluid.layers.data( + name='data2', shape=[4, 3, 3], dtype='float32') + out2 = fluid.layers.group_norm( + input=data2, groups=2, data_layout="NCHW") + + data1_np = np.random.random((2, 3, 3, 4)).astype("float32") + data2_np = np.random.random((2, 4, 3, 3)).astype("float32") + scale = np.array([1]).astype("float32") + bias = np.array([0]).astype("float32") + + place = core.CPUPlace() + exe = fluid.Executor(place) + results = exe.run(fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2], + return_numpy=True) + expect_res1 = group_norm_naive( + data1_np, scale, bias, epsilon=1e-5, groups=2, data_layout="NHWC") + expect_res2 = group_norm_naive( + data2_np, scale, bias, epsilon=1e-5, groups=2, data_layout="NCHW") + self.assertTrue(np.allclose(results[0], expect_res1[0])) + self.assertTrue(np.allclose(results[1], expect_res2[0])) + + # data_layout is not NHWC or NCHW + def test_case2(self): + data = fluid.layers.data(name='data', shape=[3, 3, 4], dtype="float32") + try: + out = fluid.layers.group_norm( + input=data, groups=2, data_layout="NDHW") + except: + pass + + if __name__ == '__main__': unittest.main()