Commit ae7d2286 authored by Dun, committed by qingqing01

Group Norm (#13843)

Add group normalization operator.
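The operator normalizes each sample over G groups of channels, matching the group_norm_naive reference used in the new unit test:

y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}} + \beta_c, \qquad g = \lfloor c / (C / G) \rfloor

where \mu_{n,g} and \sigma^2_{n,g} are the mean and (biased) variance over all channels of group g and all spatial positions of sample n, and \gamma (Scale) and \beta (Bias) are optional per-channel parameters.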
Parent de2db117
......@@ -25,6 +25,7 @@
| kexinzhao | Ke-Xin Zhao |
| kuke | Yi-Bing Liu |
| lcy-seso | Ying Cao |
| cjld | Dun Liang |
| lipeng-unisound | Peng Li |
| liuyuan | Yuan Liu |
| livc | Zhao Li |
......
......@@ -103,6 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.group_norm ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/group_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
class GroupNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mean"),
"Output(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Variance"),
"Output(Variance) of GroupNormOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto channel_num = x_dim[1];
auto batch_size = x_dim[0];
auto groups = ctx->Attrs().Get<int>("groups");
PADDLE_ENFORCE_LE(
groups, channel_num,
"'groups' must be less equal than the number of channels.");
PADDLE_ENFORCE_GE(groups, 1, "'groups' must be greater equal than 1.");
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], channel_num);
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], channel_num);
}
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->SetOutputDim("Mean", {batch_size, groups});
ctx->SetOutputDim("Variance", {batch_size, groups});
ctx->ShareLoD("X", "Y");
}
};
class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C"
"that is applied to the output.")
.AsDispensable();
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output")
.AsDispensable();
AddOutput("Y", "Result after normalization.");
AddOutput("Mean", "Mean of each group.").AsIntermediate();
AddOutput("Variance", "Variance of each group.").AsIntermediate();
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 1.0f,
"'epsilon' should be between 0.0 and 1.0.");
});
AddAttr<int>("groups", "The number of groups that divided from channels.")
.AddCustomChecker([](const int &groups) {
PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero.");
});
AddComment(R"DOC(
Group Normalization
Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_
)DOC");
}
};
class GroupNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) of GroupNormOp should not be null.");
// check output
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Scale"),
ctx->GetInputDim("Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp);
REGISTER_OP_CPU_KERNEL(
group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
group_norm_grad,
ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/group_norm_op.h"
namespace paddle {
namespace operators {
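// Forward step 1: one block per (batch, group, channel-in-group) triple
// (blockIdx.z = batch, blockIdx.y = group, blockIdx.x = channel offset in the
// group); threads stride over the spatial positions and accumulate the group's
// E[x] and E[x^2] into mean/var through a shared-memory + atomic reduction.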
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C,
int imsize, int groups,
int group_size, T* mean, T* var) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = 0, x_var = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val = x[(bid * C + ccid) * imsize + imid];
x_mean += val;
x_var += val * val;
}
x_mean /= number * imsize;
x_var /= number * imsize;
__shared__ T s_mem[2];
if (threadIdx.x == 0) {
s_mem[0] = s_mem[1] = 0;
}
__syncthreads();
paddle::platform::CudaAtomicAdd(&s_mem[0], x_mean);
paddle::platform::CudaAtomicAdd(&s_mem[1], x_var);
__syncthreads();
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(&mean[bid * groups + gid], s_mem[0]);
paddle::platform::CudaAtomicAdd(&var[bid * groups + gid], s_mem[1]);
}
}
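// Forward step 2: at this point mean[] holds E[x] and var[] holds E[x^2];
// convert E[x^2] to the biased variance, store it in real_var, then normalize
// every element and apply the optional per-channel Scale/Bias.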
template <typename T>
__global__ void GroupNormForward(const T* x, const T* mean, const T* var,
const T* scale, const T* bias, int N, int C,
int imsize, int groups, int group_size,
T epsilon, T* y, T* real_var) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = mean[bid * groups + gid];
T x_var = var[bid * groups + gid];
x_var = x_var - x_mean * x_mean;
T var_inv = 1.0 / sqrt(x_var + epsilon);
if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val = x[(bid * C + ccid) * imsize + imid];
val = (val - x_mean) * var_inv;
if (scale) val *= scale[gid * group_size + cid];
if (bias) val += bias[gid * group_size + cid];
y[(bid * C + ccid) * imsize + imid] = val;
}
}
template <typename T>
class GroupNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* var = ctx.Output<Tensor>("Variance");
const auto groups = ctx.Attr<int>("groups");
const auto x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Tensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, mean, static_cast<T>(0));
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
auto* temp_var_data = temp_var.data<T>();
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = x_dims[2] * x_dims[3];
int block_size = std::min(512, imsize);
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], x_dims[1], imsize, groups, group_size, mean_data,
temp_var_data);
GroupNormForward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, scale_data, bias_data, x_dims[0],
x_dims[1], imsize, groups, group_size, epsilon, y_data, var_data);
}
};
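// Backward step 1: per (batch, group), accumulate the per-channel
// d_scale/d_bias, the gradient w.r.t. 1/sqrt(var + eps) into d_var and the
// direct part of the mean gradient into d_mean, while writing the elementwise
// term dy * scale / sqrt(var + eps) into d_x.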
template <typename T>
__global__ void GroupNormBackwardGetMeanAndVar(
const T* x, const T* mean, const T* var, const T* scale, const T* d_y,
int N, int C, int imsize, int groups, int group_size, T epsilon, T* d_x,
T* d_mean, T* d_var, T* d_scale, T* d_bias) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = mean[bid * groups + gid];
T x_var = var[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
T d_var_inv = 0, d_x_mean = 0;
T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T tmp = x[(bid * C + ccid) * imsize + imid];
T val = (tmp - x_mean) * var_inv;
T dval = d_y[(bid * C + ccid) * imsize + imid];
if (d_bias) d_bias_data += dval;
if (d_scale) d_scale_data += val * dval;
if (scale) dval = dval * scale[ccid];
d_var_data += (tmp - x_mean) * dval;
T d_tmp = dval * var_inv;
if (d_x) d_x[(bid * C + ccid) * imsize + imid] = d_tmp;
d_mean_data -= d_tmp;
}
__shared__ T s_mem[4];
if (threadIdx.x == 0) {
s_mem[0] = s_mem[1] = 0;
if (d_scale) s_mem[2] = 0;
if (d_bias) s_mem[3] = 0;
}
__syncthreads();
paddle::platform::CudaAtomicAdd(&s_mem[0], d_mean_data);
paddle::platform::CudaAtomicAdd(&s_mem[1], d_var_data);
if (d_scale) paddle::platform::CudaAtomicAdd(&s_mem[2], d_scale_data);
if (d_bias) paddle::platform::CudaAtomicAdd(&s_mem[3], d_bias_data);
__syncthreads();
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(&d_mean[bid * groups + gid], s_mem[0]);
paddle::platform::CudaAtomicAdd(&d_var[bid * groups + gid], s_mem[1]);
if (d_scale) paddle::platform::CudaAtomicAdd(&d_scale[ccid], s_mem[2]);
if (d_bias) paddle::platform::CudaAtomicAdd(&d_bias[ccid], s_mem[3]);
}
}
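// Backward step 2: convert the accumulated gradients into gradients w.r.t. the
// group variance and mean, and add their contributions (a constant term plus a
// term proportional to x) onto d_x.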
template <typename T>
__global__ void GroupNormBackward(const T* x, const T* mean, const T* var,
const T* d_mean, const T* d_var, int N, int C,
int imsize, int groups, int group_size,
T epsilon, T* d_x) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = mean[bid * groups + gid];
T x_var = var[bid * groups + gid];
T d_x_mean = d_mean[bid * groups + gid];
T d_var_inv = d_var[bid * groups + gid];
T d_x_var =
-1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
d_x_mean -= 2 * d_x_var * x_mean;
d_x_var /= number * imsize;
d_x_mean /= number * imsize;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T tmp = x[(bid * C + ccid) * imsize + imid];
if (d_x)
d_x[(bid * C + ccid) * imsize + imid] += d_x_mean + tmp * 2 * d_x_var;
}
}
template <typename T>
class GroupNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<Tensor>("X");
auto* mean = ctx.Input<Tensor>("Mean");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
// init output
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto& x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
T* d_x_data = nullptr;
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
d_x_data = d_x->data<T>();
}
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Tensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
T* temp_var_data = temp_var.data<T>();
Tensor temp_mean;
temp_mean.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
T* temp_mean_data = temp_mean.data<T>();
auto* x_data = x->data<T>();
auto* y_data = d_y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
int imsize = x_dims[2] * x_dims[3];
int block_size = std::min(512, imsize);
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
GroupNormBackwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, var_data, scale_data, y_data, x_dims[0], x_dims[1],
imsize, groups, group_size, epsilon, d_x_data, temp_mean_data,
temp_var_data, d_scale_data, d_bias_data);
GroupNormBackward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, var_data, temp_mean_data, temp_var_data, x_dims[0],
x_dims[1], imsize, groups, group_size, epsilon, d_x_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
group_norm,
ops::GroupNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::GroupNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
group_norm_grad,
ops::GroupNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GroupNormGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename DeviceContext, typename T>
class GroupNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* var = ctx.Output<Tensor>("Variance");
const auto groups = ctx.Attr<int>("groups");
const auto x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = x_dims[2] * x_dims[3];
auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
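// For each (batch, group): the first sweep accumulates E[x] and E[x^2] over
// the group's channels and spatial positions, the second sweep normalizes and
// applies the optional per-channel scale/bias.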
for (int bid = 0; bid < x_dims[0]; bid++)
for (int gid = 0; gid < groups; gid++) {
T x_mean = 0, x_var = 0;
int number = std::min(group_size,
static_cast<int>(x_dims[1] - gid * group_size));
auto* tmp = iter_x_data;
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
x_mean /= number * imsize;
x_var /= number * imsize;
x_var = x_var - x_mean * x_mean;
T var_inv = 1.0 / sqrt(x_var + epsilon);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, tmp++, iter_y_data++) {
T val = (tmp[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
}
}
};
template <typename DeviceContext, typename T>
class GroupNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<Tensor>("X");
auto* mean = ctx.Input<Tensor>("Mean");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
// init output
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto& x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
// TODO(liangdun): need to check d_x is null
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* d_x_data = nullptr;
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_x, static_cast<T>(0));
d_x_data = d_x->data<T>();
}
auto* x_data = x->data<T>();
auto* y_data = d_y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
int imsize = x_dims[2] * x_dims[3];
auto* iter_x_data = x_data;
auto* iter_d_x_data = d_x_data;
auto* iter_y_data = y_data;
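// Mirrors the CUDA backward: the first sweep accumulates d_scale/d_bias, the
// gradient w.r.t. 1/sqrt(var + eps) and the direct mean gradient while writing
// the elementwise part of d_x; the second sweep adds the mean/variance terms.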
for (int bid = 0; bid < x_dims[0]; bid++)
for (int gid = 0; gid < groups; gid++) {
T x_mean = mean_data[bid * groups + gid];
T x_var = var_data[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
int number = std::min(group_size,
static_cast<int>(x_dims[1] - gid * group_size));
auto* tmp = iter_x_data;
auto* tmp2 = iter_d_x_data;
T d_var_inv = 0, d_x_mean = 0;
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, tmp++, iter_y_data++, iter_d_x_data++) {
T val = (tmp[0] - x_mean) * var_inv;
T dval = iter_y_data[0];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
if (scale_data) dval = scale_data[gid * group_size + cid] * dval;
d_var_inv += (tmp[0] - x_mean) * dval;
T d_tmp = dval * var_inv;
if (d_x_data) iter_d_x_data[0] += d_tmp;
d_x_mean -= d_tmp;
}
}
T d_x_var =
-1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
d_x_mean -= 2 * d_x_var * x_mean;
d_x_var /= number * imsize;
d_x_mean /= number * imsize;
iter_d_x_data = tmp2;
if (d_x_data) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_x_data++, iter_d_x_data++) {
iter_d_x_data[0] += d_x_mean;
iter_d_x_data[0] += iter_x_data[0] * 2 * d_x_var;
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -85,6 +85,7 @@ __all__ = [
'row_conv',
'multiplex',
'layer_norm',
'group_norm',
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
......@@ -2547,6 +2548,84 @@ def layer_norm(input,
return helper.append_activation(layer_norm_out)
@templatedoc()
def group_norm(input,
groups,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
data_layout='NCHW',
name=None):
"""
**Group Normalization Layer**
Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_
Args:
input(Variable): The input tensor variable.
groups(int): The number of groups into which the channels are divided.
epsilon(float): The small value added to the variance to prevent
division by zero.
param_attr(ParamAttr|None): The parameter attribute for the learnable
scale :math:`g`. If it is set to False, no scale will be added to the output units.
If it is set to None, the scale is initialized to one. Default: None.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized to zero. Default: None.
act(str): Activation to be applied to the output of group normalization.
data_layout(string, default NCHW): Only NCHW is supported.
name (str): The name of this layer. It is optional.
Returns:
Variable: A tensor variable which is the result after applying group normalization on the input.
Examples:
>>> data = fluid.layers.data(name='data', shape=[8, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.group_norm(input=data, groups=4)
"""
helper = LayerHelper('group_norm', **locals())
dtype = helper.input_dtype()
# create input and parameters
inputs = {'X': input}
input_shape = input.shape
if data_layout != 'NCHW':
raise ValueError("unsupported data layout:" + data_layout)
param_shape = [input_shape[1]]
if param_attr:
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
inputs['Scale'] = scale
if bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
inputs['Bias'] = bias
# create output
mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
group_norm_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="group_norm",
inputs=inputs,
outputs={
"Y": group_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": epsilon,
"groups": groups})
return helper.append_activation(group_norm_out)
def conv2d_transpose(input,
num_filters,
output_size=None,
......
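For completeness, a minimal usage sketch of the new layer (assuming it is called inside a fluid program; the variable names are illustrative):

```python
import paddle.fluid as fluid

# NCHW input: 8 channels over a 32x32 spatial grid (the batch dim is implicit).
data = fluid.layers.data(name='data', shape=[8, 32, 32], dtype='float32')

# Normalize the 8 channels in 4 groups of 2; act='relu' appends a ReLU to the
# normalized output via helper.append_activation.
out = fluid.layers.group_norm(input=data, groups=4, epsilon=1e-5, act='relu')
```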
......@@ -381,8 +381,8 @@ class OpTest(unittest.TestCase):
outs.sort(key=len)
checker(outs)
def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
def _assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names):
abs_a = np.abs(a)
......@@ -451,9 +451,9 @@ class OpTest(unittest.TestCase):
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place))
self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place))
@staticmethod
def _numpy_to_lod_tensor(np_value, lod, place):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
from testsuite import create_op
def group_norm_naive(x, scale, bias, epsilon, groups):
N, C, H, W = x.shape
G = groups
x = x.reshape((N * G, -1))
mean = np.mean(x, axis=1, keepdims=True)
var = np.var(x, axis=1, keepdims=True)
output = (x - mean) / np.sqrt(var + epsilon)
output = output.reshape((N, C, H, W)) * scale.reshape(
(-1, 1, 1)) + bias.reshape((-1, 1, 1))
return output, mean.reshape((N, G)), var.reshape((N, G))
class TestGroupNormOp(OpTest):
def setUp(self):
self.op_type = "group_norm"
self.data_format = "NCHW"
self.dtype = np.float32
self.shape = (2, 4, 3, 3)
self.attrs = {'epsilon': 1e-5, 'groups': 2}
self.compare_between_place = False
self.init_test_case()
input = np.random.random(self.shape).astype(self.dtype)
scale = np.random.random([self.shape[1]]).astype(self.dtype)
bias = np.random.random([self.shape[1]]).astype(self.dtype)
output, mean, var = group_norm_naive(
input, scale, bias, self.attrs['epsilon'], self.attrs['groups'])
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(input),
'Scale': OpTest.np_dtype_to_fluid_dtype(scale),
'Bias': OpTest.np_dtype_to_fluid_dtype(bias)
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
def test_check_output(self):
atol = 1e-4
place = core.CPUPlace()
self.check_output_with_place(place, atol=atol)
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def do_compare_between_place(self):
if not core.is_compiled_with_cuda(): return
place = core.CPUPlace()
place2 = core.CUDAPlace(0)
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
inputs_to_check = set(['X', 'Scale', 'Bias'])
output_names = 'Y'
cpu_grads = self._get_gradient(inputs_to_check, place, output_names,
None)
gpu_grads = self._get_gradient(inputs_to_check, place2, output_names,
None)
self._assert_is_close(cpu_grads, gpu_grads, inputs_to_check, 0.005,
"Gradient Check On %s" % str(place))
def test_check_grad(self):
if self.compare_between_place:
self.do_compare_between_place()
return
place = core.CPUPlace()
self.check_grad_with_place(
place, set(['X', 'Scale', 'Bias']), 'Y', max_relative_error=0.01)
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['X', 'Scale', 'Bias']),
'Y',
max_relative_error=0.01)
def init_test_case(self):
pass
class TestGroupNormOp1(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 1
class TestGroupNormOp2(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 4
class TestGroupNormOpBigEps1(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 1
self.attrs['epsilon'] = 0.5
class TestGroupNormOpBigEps2(TestGroupNormOp):
def init_test_case(self):
self.attrs['groups'] = 4
self.attrs['epsilon'] = 0.5
class TestGroupNormOpBigEps3(TestGroupNormOp):
def init_test_case(self):
self.attrs['epsilon'] = 0.5
class TestGroupNormOpLargeData(TestGroupNormOp):
def init_test_case(self):
self.shape = (2, 32, 64, 64)
self.attrs['groups'] = 8
self.compare_between_place = True
if __name__ == '__main__':
unittest.main()