Commit de9bec60 authored by Zhang Ting, committed by Aurelius84

lrn supports channel_last input, test=develop (#20954)

Parent 9b666cae
paddle/fluid/operators/lrn_op.cc

@@ -14,7 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/lrn_op.h"
 #include <string>
+#include <vector>
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -23,18 +25,41 @@ namespace paddle {
 namespace operators {

 using framework::Tensor;
+using DataLayout = framework::DataLayout;

 template <typename T>
 struct LRNFunctor<platform::CPUDeviceContext, T> {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor& input, framework::Tensor* out,
                   framework::Tensor* mid, int N, int C, int H, int W, int n,
-                  T k, T alpha, T beta) {
-    const T* idata = input.data<T>();
+                  T k, T alpha, T beta, const DataLayout data_layout) {
     auto place = ctx.GetPlace();
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
-    T* odata = out->mutable_data<T>(place);
-    T* mdata = mid->mutable_data<T>(place);
+    math::Transpose<platform::CPUDeviceContext, T, 4> transpose;
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    Tensor in_transpose, mid_transpose, out_transpose;
+    // if channel_last, transpose to channel_first
+    if (data_layout == DataLayout::kNHWC) {
+      auto in_dims = input.dims();
+      std::vector<int64_t> shape(
+          {in_dims[0], in_dims[3], in_dims[1], in_dims[2]});
+      in_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+      mid_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+      out_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+      std::vector<int> axis = {0, 3, 1, 2};
+      transpose(dev_ctx, input, &in_transpose, axis);
+    } else {
+      in_transpose = input;
+      mid_transpose = *mid;
+      out_transpose = *out;
+      mid_transpose.mutable_data<T>(mid->dims(), place);
+      out_transpose.mutable_data<T>(out->dims(), place);
+    }
+
+    const T* idata = in_transpose.data<T>();
+    T* odata = out_transpose.data<T>();
+    T* mdata = mid_transpose.data<T>();
+
     Tensor squared;
     T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
     std::memset(sdata, 0, sizeof(T) * squared.numel());
@@ -67,6 +92,13 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
     // compute the final output
     blas.VPOW(mid->numel(), mdata, -beta, odata);
     blas.VMUL(mid->numel(), odata, idata, odata);
+
+    // if channel_last, transpose the output(NCHW) to channel_last
+    if (data_layout == DataLayout::kNHWC) {
+      std::vector<int> axis = {0, 2, 3, 1};
+      transpose(dev_ctx, mid_transpose, mid, axis);
+      transpose(dev_ctx, out_transpose, out, axis);
+    }
   }
 };
 template struct LRNFunctor<platform::CPUDeviceContext, float>;
@@ -78,7 +110,7 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& x, const framework::Tensor& out,
                   const framework::Tensor& mid, framework::Tensor* x_g,
                   const framework::Tensor& out_g, int N, int C, int H, int W,
-                  int n, T alpha, T beta) {
+                  int n, T alpha, T beta, const DataLayout data_layout) {
     T ratio = -2 * alpha * beta;
     auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
     x_g_e = x_g_e.constant(0.0);
@@ -93,17 +125,17 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
     const int end = start + n;
     for (int m = 0; m < N; m++) {
       for (int i = 0; i < C; i++) {
-        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                             Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                     Eigen::array<int, 4>({{1, 1, H, W}}));
-        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
+        auto offsets = Eigen::array<int, 4>({{m, i, 0, 0}});
+        auto extents = Eigen::array<int, 4>({{1, 1, H, W}});
+        if (data_layout == DataLayout::kNHWC) {
+          offsets = Eigen::array<int, 4>({{m, 0, 0, i}});
+          extents = Eigen::array<int, 4>({{1, H, W, 1}});
+        }
+
+        auto i_x = e_x.slice(offsets, extents);
+        auto i_x_g = e_x_g.slice(offsets, extents);
+        auto i_out_g = e_out_g.slice(offsets, extents);
+        auto i_mid = e_mid.slice(offsets, extents);

         i_x_g = i_mid.pow(-beta) * i_out_g;
         for (int c = start; c < end; c++) {
@@ -112,14 +144,14 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
             continue;
           }

-          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                       Eigen::array<int, 4>({{1, 1, H, W}}));
+          if (data_layout != DataLayout::kNHWC) {
+            offsets = Eigen::array<int, 4>({{m, ch, 0, 0}});
+          } else {
+            offsets = Eigen::array<int, 4>({{m, 0, 0, ch}});
+          }
+
+          auto c_out = e_out.slice(offsets, extents);
+          auto c_mid = e_mid.slice(offsets, extents);
+          auto c_out_g = e_out_g.slice(offsets, extents);

           i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
         }
@@ -156,9 +188,8 @@ class LRNOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
         platform::CanMKLDNNBeUsed(ctx)) {
@@ -242,8 +273,8 @@ $$
 Function implementation:

-Inputs and outpus are in NCHW format, while input.shape.ndims() equals 4.
-And dimensions 0 ~ 3 represent batch size, feature maps, rows,
+Inputs and outpus are in NCHW or NHWC format, while input.shape.ndims() equals 4.
+If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows,
 and columns, respectively.

 Input and Output in the formula above is for each map(i) of one image, and
@@ -275,9 +306,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
         platform::CanMKLDNNBeUsed(ctx)) {
......
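For reference, the CPU functor above handles NHWC input by transposing to NCHW, reusing the existing NCHW computation, and transposing the results back. The NumPy sketch below mirrors that strategy, using the LRN formula from the op comment (mid = k + alpha * windowed sum of squares over n neighbouring channels, out = x * mid^(-beta)); the helper names and the exact window rounding are illustrative, not taken from the commit.

```python
import numpy as np

def lrn_reference_nchw(x, n=5, k=1.0, alpha=1e-4, beta=0.75):
    # Cross-map LRN on an NCHW array: mid = k + alpha * sum of squares over a
    # window of n neighbouring channels, out = x * mid ** (-beta).
    N, C, H, W = x.shape
    start = -((n - 1) // 2)
    mid = np.full_like(x, k)
    for i in range(C):
        lo, hi = max(0, i + start), min(C, i + start + n)
        mid[:, i] += alpha * np.square(x[:, lo:hi]).sum(axis=1)
    return x * np.power(mid, -beta), mid

def lrn_reference(x, data_format='NCHW', **kwargs):
    # Channel-last input: transpose to NCHW, compute, transpose back.
    if data_format == 'NHWC':
        out, mid = lrn_reference_nchw(np.transpose(x, (0, 3, 1, 2)), **kwargs)
        return np.transpose(out, (0, 2, 3, 1)), np.transpose(mid, (0, 2, 3, 1))
    return lrn_reference_nchw(x, **kwargs)

x_nchw = np.random.rand(2, 3, 4, 5).astype('float32')
out_a, _ = lrn_reference(x_nchw, data_format='NCHW')
out_b, _ = lrn_reference(np.transpose(x_nchw, (0, 2, 3, 1)), data_format='NHWC')
assert np.allclose(out_a, np.transpose(out_b, (0, 3, 1, 2)))
```

Running the sketch on an NCHW array and on its NHWC transpose should give outputs that match up to a transpose, which is also what the new unit test in this commit checks against the operator.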
paddle/fluid/operators/lrn_op.cu

@@ -17,15 +17,20 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+using DataLayout = framework::DataLayout;
+
 template <typename T>
 __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
-                                   int H, int W, int size, T k, T alpha) {
+                                   int H, int W, int size, T k, T alpha,
+                                   const DataLayout data_layout) {
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < img_size) {
     const int w = idx % W;
     const int h = (idx / W) % H;
     const int n = idx / W / H;
-    const int offset = (n * C * H + h) * W + w;
+    const int offset =
+        (data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
+                                          : ((n * H + h) * W + w) * C);
     in += offset;
     mid += offset;
@@ -37,15 +42,21 @@ __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
     int index = 0;
     while (index < C + post_pad) {
       if (index < C) {
-        T val = in[index * step];
+        int in_idx = (data_layout != DataLayout::kNHWC ? index * step : index);
+        T val = in[in_idx];
         accum += val * val;
       }
       if (index >= size) {
-        T val = in[(index - size) * step];
+        int in_idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
+                                                       : index - size);
+        T val = in[in_idx];
         accum -= val * val;
       }
       if (index >= post_pad) {
-        mid[(index - post_pad) * step] = k + accum * alpha;
+        int mid_idx =
+            (data_layout != DataLayout::kNHWC ? (index - post_pad) * step
+                                              : index - post_pad);
+        mid[mid_idx] = k + accum * alpha;
       }
       ++index;
     }
@@ -64,14 +75,14 @@ __global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
 template <typename T>
 void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
                     T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
-                    T alpha, T beta) {
+                    T alpha, T beta, const DataLayout data_layout) {
   int img_size = N * H * W;
   const int block_size = 1024;
   int grid_size = (img_size + block_size - 1) / block_size;

   auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
   KeCMRNormFillScale<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-      img_size, inputs, mid, C, H, W, n, k, alpha);
+      img_size, inputs, mid, C, H, W, n, k, alpha, data_layout);

   int input_size = N * H * W * C;
   grid_size = (input_size + block_size - 1) / block_size;
@@ -84,10 +95,11 @@ struct LRNFunctor<platform::CUDADeviceContext, T> {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor& input, framework::Tensor* out,
                   framework::Tensor* mid, int N, int C, int H, int W, int n,
-                  T k, T alpha, T beta) {
-    CrossMapNormal<T>(
-        ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()),
-        mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta);
+                  T k, T alpha, T beta, const DataLayout data_layout) {
+    CrossMapNormal<T>(ctx, input.data<T>(),
+                      out->mutable_data<T>(ctx.GetPlace()),
+                      mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k,
+                      alpha, beta, data_layout);
   }
 };
@@ -97,14 +109,16 @@ template struct LRNFunctor<platform::CUDADeviceContext, double>;
 template <typename T>
 __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
                               const T* mid, T* x_g, const T* out_g, int C,
-                              int H, int W, int size, T negative_beta,
-                              T ratio) {
+                              int H, int W, int size, T negative_beta, T ratio,
+                              const DataLayout data_layout) {
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < img_size) {
     const int w = idx % W;
     const int h = (idx / W) % H;
     const int n = idx / W / H;
-    const int offset = (n * C * H + h) * W + w;
+    const int offset =
+        (data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
+                                          : ((n * H + h) * W + w) * C);
     x += offset;
     out += offset;
     mid += offset;
@@ -120,18 +134,20 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
     // TODO(gongwb): optimize this with thread shared array.
     while (index < C + post_pad) {
       if (index < C) {
-        x_g[index * step] = 0.0;
-        accum += out_g[index * step] * out[index * step] / mid[index * step];
+        int idx = (data_layout != DataLayout::kNHWC ? index * step : index);
+        x_g[idx] = 0.0;
+        accum += out_g[idx] * out[idx] / mid[idx];
       }
       if (index >= size) {
-        accum -= out_g[(index - size) * step] * out[(index - size) * step] /
-                 mid[(index - size) * step];
+        int idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
+                                                    : index - size);
+        accum -= out_g[idx] * out[idx] / mid[idx];
       }
       if (index >= post_pad) {
-        x_g[(index - post_pad) * step] +=
-            out_g[(index - post_pad) * step] *
-                pow(mid[(index - post_pad) * step], negative_beta) -
-            ratio * x[(index - post_pad) * step] * accum;
+        int idx = (data_layout != DataLayout::kNHWC ? (index - post_pad) * step
+                                                    : index - post_pad);
+        x_g[idx] +=
+            out_g[idx] * pow(mid[idx], negative_beta) - ratio * x[idx] * accum;
       }
       ++index;
     }
@@ -141,7 +157,8 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
 template <typename T>
 void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
                         const T* out, const T* mid, T* x_g, const T* out_g,
-                        int N, int C, int H, int W, int n, T alpha, T beta) {
+                        int N, int C, int H, int W, int n, T alpha, T beta,
+                        const DataLayout data_layout) {
   int img_size = N * H * W;

   const int block_size = 1024;
@@ -149,8 +166,8 @@ void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
   auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
   KeCMRNormDiff<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-      img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta,
-      2.0f * alpha * beta);
+      img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, 2.0f * alpha * beta,
+      data_layout);
 }

 template <typename T>
@@ -159,10 +176,10 @@ struct LRNGradFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& x, const framework::Tensor& out,
                   const framework::Tensor& mid, framework::Tensor* x_g,
                   const framework::Tensor& out_g, int N, int C, int H, int W,
-                  int n, T alpha, T beta) {
+                  int n, T alpha, T beta, const DataLayout data_layout) {
     CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(),
                           x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(),
-                          N, C, H, W, n, alpha, beta);
+                          N, C, H, W, n, alpha, beta, data_layout);
   }
 };
......
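The CUDA kernels above keep one thread per (n, h, w) position and only change the index arithmetic: the base offset of channel 0 and the per-channel stride differ between layouts (stride H*W for NCHW, stride 1 for NHWC). Below is a small NumPy check of those two formulas; the variable names are illustrative.

```python
import numpy as np

N, C, H, W = 2, 3, 4, 5
nchw = np.arange(N * C * H * W).reshape(N, C, H, W)
nhwc = nchw.transpose(0, 2, 3, 1)  # same data, channel-last view

n, c, h, w = 1, 2, 3, 4
# NCHW: base offset at channel 0, then step = H * W per channel.
base_nchw = (n * C * H + h) * W + w
assert nchw.ravel()[base_nchw + c * H * W] == nchw[n, c, h, w]
# NHWC: base offset at channel 0, then step = 1 per channel.
base_nhwc = ((n * H + h) * W + w) * C
assert nhwc.ravel()[base_nhwc + c] == nchw[n, c, h, w]
print("offset formulas agree with NumPy flattening")
```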
paddle/fluid/operators/lrn_op.h

@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once

+#include <string>
+#include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -21,12 +23,15 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+using DataLayout = framework::DataLayout;
+
 template <typename place, typename T>
 struct LRNFunctor {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor& input, framework::Tensor* out,
                   framework::Tensor* mid, int N, int C, int H, int W, int n,
-                  T k, T alpha, T beta);
+                  T k, T alpha, T beta,
+                  const DataLayout data_layout = DataLayout::kAnyLayout);
 };

 template <typename DeviceContext, typename T>
@@ -42,11 +47,14 @@ class LRNKernel : public framework::OpKernel<T> {
     const Tensor& x = *ctx.Input<Tensor>("X");
     auto x_dims = x.dims();

+    const std::string data_layout_str = ctx.Attr<std::string>("data_format");
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);
     // NCHW
     int N = x_dims[0];
-    int C = x_dims[1];
-    int H = x_dims[2];
-    int W = x_dims[3];
+    int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
+    int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
+    int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);

     Tensor* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
@@ -65,7 +73,7 @@ class LRNKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");

     LRNFunctor<DeviceContext, T> f;
-    f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta);
+    f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta, data_layout);
   }
 };
@@ -75,7 +83,8 @@ struct LRNGradFunctor {
                   const framework::Tensor& x, const framework::Tensor& out,
                   const framework::Tensor& mid, framework::Tensor* x_g,
                   const framework::Tensor& out_g, int N, int C, int H, int W,
-                  int n, T alpha, T beta);
+                  int n, T alpha, T beta,
+                  const DataLayout data_layout = DataLayout::kAnyLayout);
 };

 /**
@@ -106,15 +115,18 @@ class LRNGradKernel : public framework::OpKernel<T> {
     const Tensor& out = *ctx.Input<Tensor>("Out");
     const Tensor& out_g = *ctx.Input<Tensor>(framework::GradVarName("Out"));
     const Tensor& mid = *ctx.Input<Tensor>("MidOut");
+    const std::string data_layout_str = ctx.Attr<std::string>("data_format");
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);

     auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
     x_g->mutable_data<T>(ctx.GetPlace());

     auto x_dims = x.dims();
     int N = x_dims[0];
-    int C = x_dims[1];
-    int H = x_dims[2];
-    int W = x_dims[3];
+    int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
+    int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
+    int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);

     int n = ctx.Attr<int>("n");
     T alpha = ctx.Attr<T>("alpha");
@@ -125,7 +137,7 @@ class LRNGradKernel : public framework::OpKernel<T> {
                    "is_test attribute should be set to False in training phase.");

     LRNGradFunctor<DeviceContext, T> f;
-    f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
+    f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta, data_layout);
   }
 };
......
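The kernels above also remap the raw x_dims according to data_format before calling the functors: the same four dimensions are read as [N, C, H, W] under NCHW and as [N, H, W, C] under NHWC. A minimal sketch of that mapping with a hypothetical helper:

```python
def extract_nchw(x_dims, data_format='NCHW'):
    # Mirror of the ternaries in the kernel: pick C, H, W per layout.
    N = x_dims[0]
    if data_format == 'NHWC':
        H, W, C = x_dims[1], x_dims[2], x_dims[3]
    else:
        C, H, W = x_dims[1], x_dims[2], x_dims[3]
    return N, C, H, W

assert extract_nchw([2, 3, 5, 5], 'NCHW') == (2, 3, 5, 5)
assert extract_nchw([2, 5, 5, 3], 'NHWC') == (2, 3, 5, 5)
```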
python/paddle/fluid/layers/nn.py

@@ -9277,7 +9277,8 @@ def lod_append(x, level):
     return out


-def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
+def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
+        data_format='NCHW'):
     """
     This operator implements the Local Response Normalization Layer.
     This layer performs a type of "lateral inhibition" by normalizing over local input regions.
@@ -9298,13 +9299,18 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
     Args:
-        input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError.
+        input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W] or [N, H, W, C],
+            where N is the batch size, C is the input channel, H is Height, W is weight. The data
+            type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError.
         n (int, optional): The number of channels to sum over. Default: 5
         k (float, optional): An offset, positive. Default: 1.0
         alpha (float, optional): The scaling parameter, positive. Default:1e-4
         beta (float, optional): The exponent, positive. Default:0.75
-        name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
+        name (str, optional): The default value is None. Normally there is no need for user to set
+            this property. For more information, please refer to :ref:`api_guide_Name`
+        data_format(str, optional): The data format of the input and output data. An optional string
+            from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of:
+            `[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'.

     Returns:
         Variable: A tensor variable storing the transformation result with the same shape and data type as input.
@@ -9327,8 +9333,12 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
     if dims != 4:
         raise ValueError(
-            "dims of input must be 4(not %d), and it's order must be NCHW" %
+            "Input's dimension size of Op(lrn) must be 4, but received %d." %
             (dims))
+    if data_format not in ['NCHW', 'NHWC']:
+        raise ValueError(
+            "Attr(data_format) of Op(lrn) got wrong value: received " +
+            data_format + " but only NCHW or NHWC supported.")

     mid_out = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
@@ -9340,10 +9350,13 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
             "Out": lrn_out,
             "MidOut": mid_out,
         },
-        attrs={"n": n,
-               "k": k,
-               "alpha": alpha,
-               "beta": beta})
+        attrs={
+            "n": n,
+            "k": k,
+            "alpha": alpha,
+            "beta": beta,
+            "data_format": data_format
+        })

     return lrn_out
......
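A usage sketch of the extended Python API, assuming a fluid 1.x build that includes this commit so `fluid.layers.lrn` accepts `data_format`; variable names and shapes are illustrative.

```python
import numpy as np
import paddle.fluid as fluid

# Channel-last input: [batch, height, width, channels]
data = fluid.data(name='img', shape=[2, 7, 7, 3], dtype='float32')
out = fluid.layers.lrn(data, n=5, k=1.0, alpha=1e-4, beta=0.75,
                       data_format='NHWC')

if fluid.core.is_compiled_with_cuda():
    place = fluid.CUDAPlace(0)
else:
    place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

img = np.random.random((2, 7, 7, 3)).astype('float32')
res, = exe.run(fluid.default_main_program(),
               feed={'img': img},
               fetch_list=[out])
print(res.shape)  # (2, 7, 7, 3): output keeps the channel-last layout
```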
python/paddle/fluid/tests/unittests/test_lrn_op.py

@@ -16,6 +16,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
 from op_test import OpTest
@@ -60,12 +62,15 @@ class TestLRNOp(OpTest):
             'n': self.n,
             'k': self.k,
             'alpha': self.alpha,
-            'beta': self.beta
+            'beta': self.beta,
+            'data_format': self.data_format
         }
         return attrs

     def setUp(self):
         self.op_type = "lrn"
+        self.init_test_case()
         self.N = 2
         self.C = 3
         self.H = 5
@@ -77,11 +82,18 @@ class TestLRNOp(OpTest):
         self.beta = 0.75
         self.x = self.get_input()
         self.out, self.mid_out = self.get_out()
+        if self.data_format == 'NHWC':
+            self.x = np.transpose(self.x, [0, 2, 3, 1])
+            self.out = np.transpose(self.out, [0, 2, 3, 1])
+            self.mid_out = np.transpose(self.mid_out, [0, 2, 3, 1])

         self.inputs = {'X': self.x}
         self.outputs = {'Out': self.out, 'MidOut': self.mid_out}
         self.attrs = self.get_attrs()

+    def init_test_case(self):
+        self.data_format = 'NCHW'
+
     def test_check_output(self):
         self.check_output()
@@ -89,5 +101,49 @@ class TestLRNOp(OpTest):
         self.check_grad(['X'], 'Out', max_relative_error=0.01)


+class TestLRNOpAttrDataFormat(TestLRNOp):
+    def init_test_case(self):
+        self.data_format = 'NHWC'
+
+
+class TestLRNAPI(OpTest):
+    def test_case(self):
+        data1 = fluid.data(name='data1', shape=[2, 4, 5, 5], dtype='float32')
+        data2 = fluid.data(name='data2', shape=[2, 5, 5, 4], dtype='float32')
+        out1 = fluid.layers.lrn(data1, data_format='NCHW')
+        out2 = fluid.layers.lrn(data2, data_format='NHWC')
+        data1_np = np.random.random((2, 4, 5, 5)).astype("float32")
+        data2_np = np.transpose(data1_np, [0, 2, 3, 1])
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        results = exe.run(fluid.default_main_program(),
+                          feed={"data1": data1_np,
+                                "data2": data2_np},
+                          fetch_list=[out1, out2],
+                          return_numpy=True)
+
+        self.assertTrue(
+            np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
+
+    def test_exception(self):
+        input1 = fluid.data(name="input1", shape=[2, 4, 5, 5], dtype="float32")
+        input2 = fluid.data(
+            name="input2", shape=[2, 4, 5, 5, 5], dtype="float32")
+
+        def _attr_data_fromat():
+            out = fluid.layers.lrn(input1, data_format='NDHW')
+
+        def _input_dim_size():
+            out = fluid.layers.lrn(input2)
+
+        self.assertRaises(ValueError, _attr_data_fromat)
+        self.assertRaises(ValueError, _input_dim_size)
+
+
 if __name__ == "__main__":
     unittest.main()