Unverified commit c7e739f5, authored by gongweibao, committed by GitHub

Add LRN efficient GPU implementation. (#5894)
Parent 1d1555e2
@@ -19,6 +19,103 @@ namespace operators {
using framework::Tensor;
template <typename T>
struct LRNFunctor<platform::CPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
auto x_v = framework::EigenVector<T>::Flatten(input);
const int start = -(n - 1) / 2;
const int end = start + n;
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto e_x = framework::EigenTensor<T, 4>::From(input);
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch >= 0 && ch < C) {
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
s += alpha * r.square();
}
}
}
}
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
}
};
template struct LRNFunctor<platform::CPUPlace, float>;
template struct LRNFunctor<platform::CPUPlace, double>;
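For reference, the functor above implements the standard cross-map LRN recurrence; per spatial location (m, h, w) it computes, in the notation of the operator comment below:

    MidOut(m, i, h, w) = k + alpha * sum of Input(m, ch, h, w)^2 over channels ch in the window around i
    Out(m, i, h, w) = Input(m, i, h, w) * MidOut(m, i, h, w)^(-beta)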
template <typename T>
struct LRNGradFunctor<platform::CPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta) {
T ratio = -2 * alpha * beta;
auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
x_g_e = x_g_e.constant(0.0);
auto e_x = framework::EigenTensor<T, 4>::From(x);
auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
auto e_out = framework::EigenTensor<T, 4>::From(out);
auto e_out_g = framework::EigenTensor<T, 4>::From(out_g);
auto e_mid = framework::EigenTensor<T, 4>::From(mid);
const int start = -(n - 1) / 2;
const int end = start + n;
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
i_x_g = i_mid.pow(-beta) * i_out_g;
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch < 0 || ch >= C) {
continue;
}
auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
}
}
}
}
};
template struct LRNGradFunctor<platform::CPUPlace, float>;
template struct LRNGradFunctor<platform::CPUPlace, double>;
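The gradient functor follows from differentiating Out = Input * MidOut^(-beta), where MidOut depends on every channel in the window; a sketch of the derivation behind `ratio`:

    InputGrad(i) = OutGrad(i) * MidOut(i)^(-beta)
                   - 2 * alpha * beta * Input(i) * sum over j in window(i) of OutGrad(j) * Out(j) / MidOut(j)

using Out(j) = Input(j) * MidOut(j)^(-beta) to collapse the inner term. The first summand is the `i_x_g = i_mid.pow(-beta) * i_out_g` line above; the sum is accumulated channel by channel with ratio = -2 * alpha * beta.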
class LRNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -83,8 +180,8 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Local Response Normalization Operator.
-This operator comes from the paper
-"ImageNet Classification with Deep Convolutional Neural Networks".
+This operator comes from the paper:
+<<ImageNet Classification with Deep Convolutional Neural Networks>>.
The original formula is:
@@ -119,8 +216,7 @@ class LRNOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("MidOut")),
-                   "Input(MidOut@GRAD) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("MidOut"), "Input(MidOut) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
@@ -12,11 +12,167 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"
namespace ops = paddle::operators;
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
int H, int W, int size, T k, T alpha) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < img_size) {
const int w = idx % W;
const int h = (idx / W) % H;
const int n = idx / W / H;
const int offset = (n * C * H + h) * W + w;
in += offset;
mid += offset;
const int step = H * W;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;
T accum = 0;
int index = 0;
while (index < C + post_pad) {
if (index < C) {
T val = in[index * step];
accum += val * val;
}
if (index >= size) {
T val = in[(index - size) * step];
accum -= val * val;
}
if (index >= post_pad) {
mid[(index - post_pad) * step] = k + accum * alpha;
}
++index;
}
}
}
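The kernel assigns one thread to each (n, h, w) column and sweeps it across the C channels with a running sum of squares, so each output costs O(1) rather than O(size). A minimal host-side C++ sketch of the same accumulate/subtract/emit pattern (FillScaleColumn and its std::vector interface are illustrative, not part of this commit):

    #include <vector>

    // Sliding-window sum of squares over one channel column `col` (length C),
    // mirroring KeCMRNormFillScale: add col[i] when it enters the window,
    // drop col[i - size] when it leaves, and emit the scale for the channel
    // whose window just became complete.
    std::vector<float> FillScaleColumn(const std::vector<float>& col, int size,
                                       float k, float alpha) {
      const int C = static_cast<int>(col.size());
      const int pre_pad = (size - 1) / 2;
      const int post_pad = size - pre_pad - 1;
      std::vector<float> mid(C, 0.0f);
      float accum = 0.0f;
      for (int i = 0; i < C + post_pad; ++i) {
        if (i < C) accum += col[i] * col[i];                       // value enters window
        if (i >= size) accum -= col[i - size] * col[i - size];     // value leaves window
        if (i >= post_pad) mid[i - post_pad] = k + accum * alpha;  // window complete
      }
      return mid;
    }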
template <typename T>
__global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
T negative_beta, T* out) {
const int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < input_size) {
out[index] = in[index] * pow(mid[index], negative_beta);
}
}
template <typename T>
void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
T alpha, T beta) {
int img_size = N * H * W;
const int block_size = 1024;
int grid_size = (img_size + block_size - 1) / block_size;
KeCMRNormFillScale<
T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
img_size, inputs, mid, C, H, W, n, k, alpha);
int input_size = N * H * W * C;
grid_size = (input_size + block_size - 1) / block_size;
KeCMRNormOutput<
T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
input_size, inputs, mid, -beta, outputs);
}
template <typename T>
struct LRNFunctor<platform::GPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
CrossMapNormal<T>(
ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()),
mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta);
}
};
template struct LRNFunctor<platform::GPUPlace, float>;
template struct LRNFunctor<platform::GPUPlace, double>;
template <typename T>
__global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
const T* mid, T* x_g, const T* out_g, int C,
int H, int W, int size, T negative_beta,
T ratio) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < img_size) {
const int w = idx % W;
const int h = (idx / W) % H;
const int n = idx / W / H;
const int offset = (n * C * H + h) * W + w;
x += offset;
out += offset;
mid += offset;
out_g += offset;
x_g += offset;
const int step = H * W;
const int pre_pad = size - (size + 1) / 2;
const int post_pad = size - pre_pad - 1;
int index = 0;
T accum = 0;
// TODO(gongwb): optimize this with thread shared array.
while (index < C + post_pad) {
if (index < C) {
x_g[index * step] = 0.0;
accum += out_g[index * step] * out[index * step] / mid[index * step];
}
if (index >= size) {
accum -= out_g[(index - size) * step] * out[(index - size) * step] /
mid[(index - size) * step];
}
if (index >= post_pad) {
x_g[(index - post_pad) * step] +=
out_g[(index - post_pad) * step] *
pow(mid[(index - post_pad) * step], negative_beta) -
ratio * x[(index - post_pad) * step] * accum;
}
++index;
}
}
}
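By the same sliding-window argument as in the forward kernel, when this kernel emits at channel c = index - post_pad, accum holds

    accum = sum over j in window(c) of OutGrad(j) * Out(j) / MidOut(j)

so the emitted value reproduces the CPU formula InputGrad(c) = OutGrad(c) * MidOut(c)^(-beta) - ratio * Input(c) * accum.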
template <typename T>
void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
const T* out, const T* mid, T* x_g, const T* out_g,
int N, int C, int H, int W, int n, T alpha, T beta) {
int img_size = N * H * W;
const int block_size = 1024;
int grid_size = (img_size + block_size - 1) / block_size;
KeCMRNormDiff<
T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta,
2.0f * alpha * beta);
}
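Note the sign convention: this wrapper passes ratio = 2.0f * alpha * beta as a positive value and the kernel subtracts ratio * x * accum, which matches the CPU path, where ratio = -2 * alpha * beta and the term is added.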
template <typename T>
struct LRNGradFunctor<platform::GPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta) {
CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(),
x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(),
N, C, H, W, n, alpha, beta);
}
};
template struct LRNGradFunctor<platform::GPUPlace, float>;
template struct LRNGradFunctor<platform::GPUPlace, double>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
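Only float kernels are registered for the GPU here, although the functors above are also instantiated for double.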
@@ -21,6 +21,14 @@
namespace paddle {
namespace operators {
template <typename place, typename T>
struct LRNFunctor {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta);
};
template <typename Place, typename T>
class LRNKernel : public framework::OpKernel<T> {
public:
@@ -31,8 +39,8 @@ class LRNKernel : public framework::OpKernel<T> {
// f(x) represents outputs
void Compute(const framework::ExecutionContext& ctx) const override {
// input
-    const Tensor* x = ctx.Input<Tensor>("X");
-    auto x_dims = x->dims();
+    const Tensor& x = *ctx.Input<Tensor>("X");
+    auto x_dims = x.dims();
// NCHW
int N = x_dims[0];
@@ -57,38 +65,20 @@ class LRNKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
-    auto x_v = framework::EigenVector<T>::Flatten(*x);
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-    e_mid.device(ctx.GetEigenDevice<Place>()) = e_mid.constant(k);
-    auto e_x = framework::EigenTensor<T, 4>::From(*x);
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        for (int c = start; c <= end; c++) {
-          int ch = i + c;
-          if (ch >= 0 && ch < C) {
-            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                               Eigen::array<int, 4>({{1, 1, H, W}}));
-            s.device(ctx.GetEigenDevice<Place>()) += alpha * r.square();
-          }
-        }
-      }
-    }
-    auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e.device(ctx.GetEigenDevice<Place>()) =
-        x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
+    LRNFunctor<Place, T> f;
+    f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta);
}
};
template <typename Place, typename T>
struct LRNGradFunctor {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta);
};
/**
* \brief Backward calculation for normalization with across maps.
*
@@ -97,7 +87,7 @@ class LRNKernel : public framework::OpKernel<T> {
* The implementation of this Function is derived from the
* CrossMapNormalFunc implementation.
*
* InputGrad = OutputGrad * denoms ^ (-beta)
* InputGrad = OutputGrad * MidOut ^ (-beta)
* -- upper
* + > (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
* -- lower
@@ -113,18 +103,15 @@ class LRNGradKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* out = ctx.Input<Tensor>("Out");
-    const Tensor* out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    const Tensor* mid = ctx.Input<Tensor>("MidOut");
+    const Tensor& x = *ctx.Input<Tensor>("X");
+    const Tensor& out = *ctx.Input<Tensor>("Out");
+    const Tensor& out_g = *ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const Tensor& mid = *ctx.Input<Tensor>("MidOut");
auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
x_g->mutable_data<T>(ctx.GetPlace());
-    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
-    x_g_e.device(ctx.GetEigenDevice<Place>()) = x_g_e.constant(0.0);
-    auto x_dims = x->dims();
+    auto x_dims = x.dims();
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
@@ -133,51 +120,9 @@ class LRNGradKernel : public framework::OpKernel<T> {
int n = ctx.Attr<int>("n");
T alpha = ctx.Attr<T>("alpha");
T beta = ctx.Attr<T>("beta");
-    T ratio = -2 * alpha * beta;
-    auto e_x = framework::EigenTensor<T, 4>::From(*x);
-    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
-    auto e_out = framework::EigenTensor<T, 4>::From(*out);
-    auto e_out_g = framework::EigenTensor<T, 4>::From(*out_g);
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                             Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                     Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-        i_x_g.device(ctx.GetEigenDevice<Place>()) = i_mid.pow(-beta) * i_out_g;
-        for (int c = start; c <= end; c++) {
-          int ch = i + c;
-          if (ch < 0 || ch >= C) {
-            continue;
-          }
-          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                       Eigen::array<int, 4>({{1, 1, H, W}}));
-          i_x_g.device(ctx.GetEigenDevice<Place>()) +=
-              ratio * c_out_g * c_out * i_x / c_mid;
-        }
-      }
-    }
+    LRNGradFunctor<Place, T> f;
+    f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
}
};
@@ -23,7 +23,7 @@ class TestLRNOp(OpTest):
start = -(self.n - 1) / 2
end = start + self.n
-        mid = np.empty((self.N, self.C, self.H, self.W), dtype=float)
+        mid = np.empty((self.N, self.C, self.H, self.W)).astype("float32")
mid.fill(self.k)
for m in range(0, self.N):
for i in range(0, self.C):
@@ -74,5 +74,4 @@ class TestLRNOp(OpTest):
if __name__ == "__main__":
    exit(0)  # LRN grad implementation is wrong
unittest.main()