From 3848f7208f13eb6695b2fd04b192d04843d00c5f Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Mon, 25 Nov 2019 11:07:20 +0800 Subject: [PATCH] [cherry-pick] fix crop_tensor, maxout and lrn (#21302) * [cherry-pick] All elements in attr(shape) of crop_tensor can be -1 and int32/64 kernel registered (#20756) * All elements in attr(shape) of crop_tensor can be -1, test=develop, test=document_preview * fix the bug that attr(offsets) should be initialized, test=develop * [cherry-pick] maxout supports channel_last input (#20846) * maxout support channel_last input, test=develop * modified details of Input(X) and Attr(groups, axis) in doc, test=develop * [cherry-pick] lrn supports channel_last input, test=develop (#20954) --- paddle/fluid/operators/crop_tensor_op.cc | 31 +++- paddle/fluid/operators/crop_tensor_op.cu | 8 +- paddle/fluid/operators/crop_tensor_op.h | 35 ++-- paddle/fluid/operators/lrn_op.cc | 88 ++++++---- paddle/fluid/operators/lrn_op.cu | 71 +++++--- paddle/fluid/operators/lrn_op.h | 32 ++-- paddle/fluid/operators/math/maxouting.cc | 50 ++++-- paddle/fluid/operators/math/maxouting.cu | 79 +++++---- paddle/fluid/operators/math/maxouting.h | 6 +- paddle/fluid/operators/maxout_op.cc | 52 +++--- paddle/fluid/operators/maxout_op.h | 7 +- python/paddle/fluid/layers/nn.py | 159 ++++++++++++------ .../tests/unittests/test_crop_tensor_op.py | 117 +++++++++---- .../fluid/tests/unittests/test_lrn_op.py | 58 ++++++- .../fluid/tests/unittests/test_maxout_op.py | 54 +++++- 15 files changed, 583 insertions(+), 264 deletions(-) diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc index 43fa27ef4b1..e4a314cea06 100644 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ b/paddle/fluid/operators/crop_tensor_op.cc @@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel { "Input(X) of Op(crop_tensor) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, "Output(Out) of Op(crop_tensor) should not be null."); - + auto x_dim = ctx->GetInputDim("X"); auto shape = ctx->Attrs().Get>("shape"); + auto offsets = ctx->Attrs().Get>("offsets"); if (ctx->HasInputs("ShapeTensor")) { // top prority shape auto inputs_name = ctx->Inputs("ShapeTensor"); @@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel { "Op(fluid.layers.crop_tensor)."); auto out_dims = std::vector(inputs_name.size(), -1); for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] != -1) { + if (shape[i] > 0) { out_dims[i] = static_cast(shape[i]); + } else { + if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) { + out_dims[i] = x_dim[i] - static_cast(offsets[i]); + } } } ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); return; } - auto x_dim = ctx->GetInputDim("X"); + if (ctx->HasInput("Shape")) { auto shape_dim = ctx->GetInputDim("Shape"); PADDLE_ENFORCE_EQ( @@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(), "Attr(shape)'size of Op(crop_tensor) should be equal to " "dimention size of input tensor."); - std::vector tensor_shape(shape.size()); + std::vector out_shape(shape.size(), -1); for (size_t i = 0; i < shape.size(); ++i) { - tensor_shape[i] = static_cast(shape[i]); + if (shape[i] > 0) { + out_shape[i] = static_cast(shape[i]); + } else { + if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) { + out_shape[i] = x_dim[i] - static_cast(offsets[i]); + } + } } - ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape)); + 
ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); } framework::OpKernelType GetExpectedKernelType( @@ -294,8 +305,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad); REGISTER_OP_CPU_KERNEL( crop_tensor, ops::CropTensorKernel, - ops::CropTensorKernel); + ops::CropTensorKernel, + ops::CropTensorKernel, + ops::CropTensorKernel); REGISTER_OP_CPU_KERNEL( crop_tensor_grad, ops::CropTensorGradKernel, - ops::CropTensorGradKernel); + ops::CropTensorGradKernel, + ops::CropTensorGradKernel, + ops::CropTensorGradKernel); diff --git a/paddle/fluid/operators/crop_tensor_op.cu b/paddle/fluid/operators/crop_tensor_op.cu index 9d28d984908..c3a144d1719 100644 --- a/paddle/fluid/operators/crop_tensor_op.cu +++ b/paddle/fluid/operators/crop_tensor_op.cu @@ -17,8 +17,12 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( crop_tensor, ops::CropTensorKernel, - ops::CropTensorKernel); + ops::CropTensorKernel, + ops::CropTensorKernel, + ops::CropTensorKernel); REGISTER_OP_CUDA_KERNEL( crop_tensor_grad, ops::CropTensorGradKernel, - ops::CropTensorGradKernel); + ops::CropTensorGradKernel, + ops::CropTensorGradKernel, + ops::CropTensorGradKernel); diff --git a/paddle/fluid/operators/crop_tensor_op.h b/paddle/fluid/operators/crop_tensor_op.h index 42f118d0220..b280d6ec911 100644 --- a/paddle/fluid/operators/crop_tensor_op.h +++ b/paddle/fluid/operators/crop_tensor_op.h @@ -50,29 +50,28 @@ inline std::vector get_new_data( } static framework::DDim ValidateShape(const std::vector shape, + const std::vector offsets, const framework::DDim& in_dims) { auto in_dim_size = in_dims.size(); auto shape_size = shape.size(); PADDLE_ENFORCE_EQ( in_dim_size, shape_size, - "Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal " - "to that of input tensor. " + "Attr(shape)'s size of Op(crop_tensor) should be equal " + "to that of input Tensor. 
" "Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor)."); - const int64_t unk_dim_val = -1; - int unk_dim_idx = -1; std::vector output_shape(shape.size(), 0); for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == unk_dim_val) { - PADDLE_ENFORCE_EQ(unk_dim_idx, -1, - "Only one element of shape can be unknown."); - PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1."); - unk_dim_idx = i; + if (shape[i] <= 0 && in_dims[i] > 0) { + PADDLE_ENFORCE_NE( + shape[i], 0, + "The element in Attr(shape) of Op(crop_tensor) should not be zero."); + PADDLE_ENFORCE_EQ(shape[i], -1, + "When the element in Attr(shape) of Op(crop_tensor) is " + "negative, only -1 is supported."); + output_shape[i] = in_dims[i] - offsets[i]; } else { - PADDLE_ENFORCE_GT(shape[i], 0, - "Each element of shape must be greater than 0 " - "except the first element."); + output_shape[i] = static_cast(shape[i]); } - output_shape[i] = static_cast(shape[i]); } return framework::make_ddim(output_shape); @@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) { shape.push_back(out_dims[i]); } } - out_dims = ValidateShape(shape, x->dims()); - if (out_dims[0] == -1) { - out_dims[0] = x->dims()[0]; - } - out->mutable_data(out_dims, context.GetPlace()); - auto x_stride = framework::stride(x->dims()); auto offsets = GetOffsets(context); - int64_t offset = 0; + out_dims = ValidateShape(shape, offsets, x->dims()); + out->mutable_data(out_dims, context.GetPlace()); for (size_t i = 0; i < offsets.size(); ++i) { PADDLE_ENFORCE_LE( offsets[i] + shape[i], x_dims[i], "The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) " "should be less than or equal to corresponding input dimension size."); - offset += (x_stride[i] * offsets[i]); } auto x_tensor = EigenTensor::From(*x); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index d5b092ec99d..c17fe1348d7 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -14,7 +14,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/lrn_op.h" #include +#include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -23,18 +25,41 @@ namespace paddle { namespace operators { using framework::Tensor; +using DataLayout = framework::DataLayout; template struct LRNFunctor { void operator()(const framework::ExecutionContext& ctx, const framework::Tensor& input, framework::Tensor* out, framework::Tensor* mid, int N, int C, int H, int W, int n, - T k, T alpha, T beta) { - const T* idata = input.data(); + T k, T alpha, T beta, const DataLayout data_layout) { auto place = ctx.GetPlace(); auto blas = math::GetBlas(ctx); - T* odata = out->mutable_data(place); - T* mdata = mid->mutable_data(place); + math::Transpose transpose; + auto& dev_ctx = ctx.template device_context(); + Tensor in_transpose, mid_transpose, out_transpose; + // if channel_last, transpose to channel_first + if (data_layout == DataLayout::kNHWC) { + auto in_dims = input.dims(); + std::vector shape( + {in_dims[0], in_dims[3], in_dims[1], in_dims[2]}); + in_transpose.mutable_data(framework::make_ddim(shape), place); + mid_transpose.mutable_data(framework::make_ddim(shape), place); + out_transpose.mutable_data(framework::make_ddim(shape), place); + std::vector axis = {0, 3, 1, 2}; + transpose(dev_ctx, input, &in_transpose, axis); + } else { + in_transpose = input; + mid_transpose = *mid; + out_transpose = *out; + mid_transpose.mutable_data(mid->dims(), place); + out_transpose.mutable_data(out->dims(), place); + } + + const T* idata = in_transpose.data(); + T* odata = out_transpose.data(); + T* mdata = mid_transpose.data(); + Tensor squared; T* sdata = squared.mutable_data({1, C + n - 1, H, W}, place); std::memset(sdata, 0, sizeof(T) * squared.numel()); @@ -67,6 +92,13 @@ struct LRNFunctor { // compute the final output blas.VPOW(mid->numel(), mdata, -beta, odata); blas.VMUL(mid->numel(), odata, idata, odata); + + // if channel_last, transpose the output(NCHW) to channel_last + if (data_layout == DataLayout::kNHWC) { + std::vector axis = {0, 2, 3, 1}; + transpose(dev_ctx, mid_transpose, mid, axis); + transpose(dev_ctx, out_transpose, out, axis); + } } }; template struct LRNFunctor; @@ -78,7 +110,7 @@ struct LRNGradFunctor { const framework::Tensor& x, const framework::Tensor& out, const framework::Tensor& mid, framework::Tensor* x_g, const framework::Tensor& out_g, int N, int C, int H, int W, - int n, T alpha, T beta) { + int n, T alpha, T beta, const DataLayout data_layout) { T ratio = -2 * alpha * beta; auto x_g_e = framework::EigenVector::Flatten(*x_g); x_g_e = x_g_e.constant(0.0); @@ -93,17 +125,17 @@ struct LRNGradFunctor { const int end = start + n; for (int m = 0; m < N; m++) { for (int i = 0; i < C; i++) { - auto i_x = e_x.slice(Eigen::array({{m, i, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - auto i_x_g = e_x_g.slice(Eigen::array({{m, i, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - auto i_out_g = e_out_g.slice(Eigen::array({{m, i, 0, 0}}), - Eigen::array({{1, 1, H, W}})); + auto offsets = Eigen::array({{m, i, 0, 0}}); + auto extents = Eigen::array({{1, 1, H, W}}); + if (data_layout == DataLayout::kNHWC) { + offsets = Eigen::array({{m, 0, 0, i}}); + extents = Eigen::array({{1, H, W, 1}}); + } - auto i_mid = e_mid.slice(Eigen::array({{m, i, 0, 0}}), - Eigen::array({{1, 1, H, W}})); + auto i_x = e_x.slice(offsets, extents); + auto i_x_g = e_x_g.slice(offsets, extents); + auto i_out_g = 
e_out_g.slice(offsets, extents); + auto i_mid = e_mid.slice(offsets, extents); i_x_g = i_mid.pow(-beta) * i_out_g; for (int c = start; c < end; c++) { @@ -112,14 +144,14 @@ struct LRNGradFunctor { continue; } - auto c_out = e_out.slice(Eigen::array({{m, ch, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - auto c_mid = e_mid.slice(Eigen::array({{m, ch, 0, 0}}), - Eigen::array({{1, 1, H, W}})); - - auto c_out_g = e_out_g.slice(Eigen::array({{m, ch, 0, 0}}), - Eigen::array({{1, 1, H, W}})); + if (data_layout != DataLayout::kNHWC) { + offsets = Eigen::array({{m, ch, 0, 0}}); + } else { + offsets = Eigen::array({{m, 0, 0, ch}}); + } + auto c_out = e_out.slice(offsets, extents); + auto c_mid = e_mid.slice(offsets, extents); + auto c_out_g = e_out_g.slice(offsets, extents); i_x_g += ratio * c_out_g * c_out * i_x / c_mid; } @@ -156,9 +188,8 @@ class LRNOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { @@ -242,8 +273,8 @@ $$ Function implementation: -Inputs and outpus are in NCHW format, while input.shape.ndims() equals 4. -And dimensions 0 ~ 3 represent batch size, feature maps, rows, +Inputs and outpus are in NCHW or NHWC format, while input.shape.ndims() equals 4. +If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows, and columns, respectively. Input and Output in the formula above is for each map(i) of one image, and @@ -275,9 +306,8 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/lrn_op.cu b/paddle/fluid/operators/lrn_op.cu index 64f3fea6be2..d71aaf73773 100644 --- a/paddle/fluid/operators/lrn_op.cu +++ b/paddle/fluid/operators/lrn_op.cu @@ -17,15 +17,20 @@ limitations under the License. */ namespace paddle { namespace operators { +using DataLayout = framework::DataLayout; + template __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C, - int H, int W, int size, T k, T alpha) { + int H, int W, int size, T k, T alpha, + const DataLayout data_layout) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < img_size) { const int w = idx % W; const int h = (idx / W) % H; const int n = idx / W / H; - const int offset = (n * C * H + h) * W + w; + const int offset = + (data_layout != DataLayout::kNHWC ? 
(n * C * H + h) * W + w + : ((n * H + h) * W + w) * C); in += offset; mid += offset; @@ -37,15 +42,21 @@ __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C, int index = 0; while (index < C + post_pad) { if (index < C) { - T val = in[index * step]; + int in_idx = (data_layout != DataLayout::kNHWC ? index * step : index); + T val = in[in_idx]; accum += val * val; } if (index >= size) { - T val = in[(index - size) * step]; + int in_idx = (data_layout != DataLayout::kNHWC ? (index - size) * step + : index - size); + T val = in[in_idx]; accum -= val * val; } if (index >= post_pad) { - mid[(index - post_pad) * step] = k + accum * alpha; + int mid_idx = + (data_layout != DataLayout::kNHWC ? (index - post_pad) * step + : index - post_pad); + mid[mid_idx] = k + accum * alpha; } ++index; } @@ -64,14 +75,14 @@ __global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid, template void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs, T* outputs, T* mid, int N, int C, int H, int W, int n, T k, - T alpha, T beta) { + T alpha, T beta, const DataLayout data_layout) { int img_size = N * H * W; const int block_size = 1024; int grid_size = (img_size + block_size - 1) / block_size; auto& dev_ctx = ctx.template device_context(); KeCMRNormFillScale<<>>( - img_size, inputs, mid, C, H, W, n, k, alpha); + img_size, inputs, mid, C, H, W, n, k, alpha, data_layout); int input_size = N * H * W * C; grid_size = (input_size + block_size - 1) / block_size; @@ -84,10 +95,11 @@ struct LRNFunctor { void operator()(const framework::ExecutionContext& ctx, const framework::Tensor& input, framework::Tensor* out, framework::Tensor* mid, int N, int C, int H, int W, int n, - T k, T alpha, T beta) { - CrossMapNormal( - ctx, input.data(), out->mutable_data(ctx.GetPlace()), - mid->mutable_data(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta); + T k, T alpha, T beta, const DataLayout data_layout) { + CrossMapNormal(ctx, input.data(), + out->mutable_data(ctx.GetPlace()), + mid->mutable_data(ctx.GetPlace()), N, C, H, W, n, k, + alpha, beta, data_layout); } }; @@ -97,14 +109,16 @@ template struct LRNFunctor; template __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out, const T* mid, T* x_g, const T* out_g, int C, - int H, int W, int size, T negative_beta, - T ratio) { + int H, int W, int size, T negative_beta, T ratio, + const DataLayout data_layout) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < img_size) { const int w = idx % W; const int h = (idx / W) % H; const int n = idx / W / H; - const int offset = (n * C * H + h) * W + w; + const int offset = + (data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w + : ((n * H + h) * W + w) * C); x += offset; out += offset; mid += offset; @@ -120,18 +134,20 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out, // TODO(gongwb): optimize this with thread shared array. while (index < C + post_pad) { if (index < C) { - x_g[index * step] = 0.0; - accum += out_g[index * step] * out[index * step] / mid[index * step]; + int idx = (data_layout != DataLayout::kNHWC ? index * step : index); + x_g[idx] = 0.0; + accum += out_g[idx] * out[idx] / mid[idx]; } if (index >= size) { - accum -= out_g[(index - size) * step] * out[(index - size) * step] / - mid[(index - size) * step]; + int idx = (data_layout != DataLayout::kNHWC ? 
(index - size) * step + : index - size); + accum -= out_g[idx] * out[idx] / mid[idx]; } if (index >= post_pad) { - x_g[(index - post_pad) * step] += - out_g[(index - post_pad) * step] * - pow(mid[(index - post_pad) * step], negative_beta) - - ratio * x[(index - post_pad) * step] * accum; + int idx = (data_layout != DataLayout::kNHWC ? (index - post_pad) * step + : index - post_pad); + x_g[idx] += + out_g[idx] * pow(mid[idx], negative_beta) - ratio * x[idx] * accum; } ++index; } @@ -141,7 +157,8 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out, template void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x, const T* out, const T* mid, T* x_g, const T* out_g, - int N, int C, int H, int W, int n, T alpha, T beta) { + int N, int C, int H, int W, int n, T alpha, T beta, + const DataLayout data_layout) { int img_size = N * H * W; const int block_size = 1024; @@ -149,8 +166,8 @@ void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x, auto& dev_ctx = ctx.template device_context(); KeCMRNormDiff<<>>( - img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, - 2.0f * alpha * beta); + img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, 2.0f * alpha * beta, + data_layout); } template @@ -159,10 +176,10 @@ struct LRNGradFunctor { const framework::Tensor& x, const framework::Tensor& out, const framework::Tensor& mid, framework::Tensor* x_g, const framework::Tensor& out_g, int N, int C, int H, int W, - int n, T alpha, T beta) { + int n, T alpha, T beta, const DataLayout data_layout) { CrossMapNormalGrad(ctx, x.data(), out.data(), mid.data(), x_g->mutable_data(ctx.GetPlace()), out_g.data(), - N, C, H, W, n, alpha, beta); + N, C, H, W, n, alpha, beta, data_layout); } }; diff --git a/paddle/fluid/operators/lrn_op.h b/paddle/fluid/operators/lrn_op.h index 12d39c38153..44999970c20 100644 --- a/paddle/fluid/operators/lrn_op.h +++ b/paddle/fluid/operators/lrn_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -21,12 +23,15 @@ limitations under the License. */ namespace paddle { namespace operators { +using DataLayout = framework::DataLayout; + template struct LRNFunctor { void operator()(const framework::ExecutionContext& ctx, const framework::Tensor& input, framework::Tensor* out, framework::Tensor* mid, int N, int C, int H, int W, int n, - T k, T alpha, T beta); + T k, T alpha, T beta, + const DataLayout data_layout = DataLayout::kAnyLayout); }; template @@ -42,11 +47,14 @@ class LRNKernel : public framework::OpKernel { const Tensor& x = *ctx.Input("X"); auto x_dims = x.dims(); + const std::string data_layout_str = ctx.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); // NCHW int N = x_dims[0]; - int C = x_dims[1]; - int H = x_dims[2]; - int W = x_dims[3]; + int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]); + int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]); + int W = (data_layout != DataLayout::kNHWC ? 
x_dims[3] : x_dims[2]); Tensor* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); @@ -65,7 +73,7 @@ class LRNKernel : public framework::OpKernel { PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0"); LRNFunctor f; - f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta); + f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta, data_layout); } }; @@ -75,7 +83,8 @@ struct LRNGradFunctor { const framework::Tensor& x, const framework::Tensor& out, const framework::Tensor& mid, framework::Tensor* x_g, const framework::Tensor& out_g, int N, int C, int H, int W, - int n, T alpha, T beta); + int n, T alpha, T beta, + const DataLayout data_layout = DataLayout::kAnyLayout); }; /** @@ -106,15 +115,18 @@ class LRNGradKernel : public framework::OpKernel { const Tensor& out = *ctx.Input("Out"); const Tensor& out_g = *ctx.Input(framework::GradVarName("Out")); const Tensor& mid = *ctx.Input("MidOut"); + const std::string data_layout_str = ctx.Attr("data_format"); + const framework::DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); auto x_g = ctx.Output(framework::GradVarName("X")); x_g->mutable_data(ctx.GetPlace()); auto x_dims = x.dims(); int N = x_dims[0]; - int C = x_dims[1]; - int H = x_dims[2]; - int W = x_dims[3]; + int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]); + int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]); + int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]); int n = ctx.Attr("n"); T alpha = ctx.Attr("alpha"); @@ -125,7 +137,7 @@ class LRNGradKernel : public framework::OpKernel { "is_test attribute should be set to False in training phase."); LRNGradFunctor f; - f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta); + f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta, data_layout); } }; diff --git a/paddle/fluid/operators/math/maxouting.cc b/paddle/fluid/operators/math/maxouting.cc index 730f71e96b6..45556e97d1d 100644 --- a/paddle/fluid/operators/math/maxouting.cc +++ b/paddle/fluid/operators/math/maxouting.cc @@ -18,35 +18,45 @@ namespace paddle { namespace operators { namespace math { -// All tensors are in NCHW format, and the groups must be greater than 1 +// All tensors are in NCHW or NHWC format, and the groups must be greater than 1 template class MaxOutFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* output, - int groups) { + const int groups, const int axis) { const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? 
input.dims()[3] : input.dims()[2]); + const int output_channels = output->dims()[axis]; int fea_size = input_height * input_width; // c_size means the output size of each sample int c_size = fea_size * output_channels; const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; ++i) { int new_bindex = c_size * i; for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { T ele = static_cast(-FLT_MAX); + int input_idx, output_idx; for (int ph = 0; ph < groups; ++ph) { - T x = input_data[(new_bindex + new_cindex) * groups + - ph * fea_size + f]; + if (axis == 1) { + input_idx = + (new_bindex + new_cindex) * groups + ph * fea_size + f; + } else { + input_idx = (new_bindex + f * output_channels + c) * groups + ph; + } + T x = input_data[input_idx]; ele = ele > x ? ele : x; } - output_data[(new_bindex + new_cindex + f)] = ele; + if (axis == 1) { + output_idx = new_bindex + new_cindex + f; + } else { + output_idx = new_bindex + f * output_channels + c; + } + output_data[output_idx] = ele; } } } @@ -59,11 +69,12 @@ class MaxOutGradFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups) { + const framework::Tensor& output_grad, const int groups, + const int axis) { const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output.dims()[axis]; int fea_size = input_height * input_width; const T* input_data = input.data(); const T* output_data = output.data(); @@ -75,11 +86,18 @@ class MaxOutGradFunctor { for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; for (int f = 0; f < fea_size; ++f) { - int input_idx0 = (blen + clen) * groups + f; + int input_idx0, output_idx; bool continue_match = true; - int output_idx = blen + clen + f; + if (axis == 1) { + input_idx0 = (blen + clen) * groups + f; + output_idx = blen + clen + f; + } else { + input_idx0 = (blen + f * output_channels + c) * groups; + output_idx = blen + f * output_channels + c; + } for (int g = 0; g < groups && continue_match; ++g) { - int input_idx = input_idx0 + fea_size * g; + int idx_offset = (axis == 1 ? 
fea_size * g : g); + int input_idx = input_idx0 + idx_offset; if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; continue_match = false; diff --git a/paddle/fluid/operators/math/maxouting.cu b/paddle/fluid/operators/math/maxouting.cu index d9a23299a4d..8b134a29d81 100644 --- a/paddle/fluid/operators/math/maxouting.cu +++ b/paddle/fluid/operators/math/maxouting.cu @@ -22,8 +22,8 @@ namespace math { template __global__ void KernelMaxOut(const int nthreads, const T* input_data, const int channels, const int input_height, - const int input_width, int groups, - T* output_data) { + const int input_width, const int groups, + const int axis, T* output_data) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, for (int i = index; i < nthreads; i += offset) { int batch_idx = i / size; int batch_offset = i % size; - int channel_idx = batch_offset / feat_len; - int feat_idx = batch_offset % feat_len; - int data_idx = - (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + int channel_idx, feat_idx, data_idx; + if (axis == 1) { + channel_idx = batch_offset / feat_len; + feat_idx = batch_offset % feat_len; + data_idx = + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + } else { + channel_idx = batch_offset % channels; + feat_idx = batch_offset / channels; + data_idx = + (batch_idx * size + feat_idx * channels + channel_idx) * groups; + } T ele = static_cast(-FLT_MAX); for (int g = 0; g < groups; ++g) { - T x = input_data[data_idx + g * feat_len]; + int idx_offset = (axis == 1 ? g * feat_len : g); + T x = input_data[data_idx + idx_offset]; ele = ele > x ? ele : x; } output_data[i] = ele; @@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, const T* output_data, const T* output_grad, T* input_grad, const int channels, const int input_height, const int input_width, - int groups) { + const int groups, const int axis) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, for (int i = index; i < nthreads; i += offset) { int batch_idx = i / size; int batch_offset = i % size; - int channel_idx = batch_offset / feat_len; - int feat_idx = batch_offset % feat_len; - int data_idx = - (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + int channel_idx, feat_idx, data_idx; + if (axis == 1) { + channel_idx = batch_offset / feat_len; + feat_idx = batch_offset % feat_len; + data_idx = + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + } else { + channel_idx = batch_offset % channels; + feat_idx = batch_offset / channels; + data_idx = + (batch_idx * size + feat_idx * channels + channel_idx) * groups; + } int max_index = -1; bool continue_match = true; for (int g = 0; g < groups && continue_match; ++g) { - if (input_data[data_idx + g * feat_len] == output_data[i]) { - max_index = data_idx + g * feat_len; + int idx_offset = (axis == 1 ? 
g * feat_len : g); + if (input_data[data_idx + idx_offset] == output_data[i]) { + max_index = data_idx + idx_offset; continue_match = false; break; } @@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, } } /* - * All tensors are in NCHW format. + * All tensors are in NCHW or NHWC format. */ template class MaxOutFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* output, - int groups) { + const int groups, const int axis) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int input_channels = input.dims()[axis]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output->dims()[axis]; const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); @@ -100,11 +116,11 @@ class MaxOutFunctor { KernelMaxOut<<>>( nthreads, input_data, input_channels, input_height, input_width, groups, - output_data); + axis, output_data); } }; /* - * All tensors are in NCHW format. + * All tensors are in NCHW or NHWC format. */ template class MaxOutGradFunctor { @@ -112,14 +128,13 @@ class MaxOutGradFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups) { + const framework::Tensor& output_grad, const int groups, + const int axis) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int input_channels = input.dims()[axis]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? 
input.dims()[3] : input.dims()[2]); + const int output_channels = output.dims()[axis]; const T* input_data = input.data(); const T* output_data = output.data(); @@ -132,7 +147,7 @@ class MaxOutGradFunctor { KernelMaxoutGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, groups); + input_channels, input_height, input_width, groups, axis); } }; diff --git a/paddle/fluid/operators/math/maxouting.h b/paddle/fluid/operators/math/maxouting.h index e4d378dc232..50bddf73bc1 100644 --- a/paddle/fluid/operators/math/maxouting.h +++ b/paddle/fluid/operators/math/maxouting.h @@ -26,7 +26,8 @@ template class MaxOutFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, - framework::Tensor* output, int groups); + framework::Tensor* output, const int groups, + const int axis = 1); }; template @@ -35,7 +36,8 @@ class MaxOutGradFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups); + const framework::Tensor& output_grad, const int groups, + const int axis = 1); }; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/maxout_op.cc b/paddle/fluid/operators/maxout_op.cc index c05c1a282c2..87301178437 100644 --- a/paddle/fluid/operators/maxout_op.cc +++ b/paddle/fluid/operators/maxout_op.cc @@ -23,25 +23,27 @@ using framework::Tensor; class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput( - "X", - "(Tensor) The input tensor of maxout operator with data type of " - "float32. The format of input tensor is NCHW. Where N is batch size," - " C is the number of channels, H and W is the height and width of " - "feature."); + AddInput("X", + "A 4-D Tensor with data type of float32 or float64. " + "The data format is NCHW or NHWC. Where N is " + "batch size, C is the number of channels, " + "H and W is the height and width of " + "feature. "); AddOutput("Out", - "(Tensor) The output tensor of maxout operator." - "The data type is float32." - "The format of output tensor is also NCHW." - "Where N is batch size, C is " - "the number of channels, H and W is the height and " - "width of feature."); + "A 4-D Tensor with same data type and data format " + "with input Tensor. "); AddAttr( "groups", - "(int)," - "Specifies how many groups the input tensor will be split" - "in the channel dimension. And the number of output channel is " - "the number of channels divided by groups."); + "Specifies how many groups the input tensor will be split into " + "at the channel dimension. And the number of output channel is " + "the number of channels divided by groups. "); + AddAttr( + "axis", + "Specifies the index of channel dimension where maxout will " + "be performed. It should be 1 when data format is NCHW, -1 or 3 " + "when data format is NHWC. " + "Default: 1. ") + .SetDefault(1); AddComment(R"DOC( MaxOut Operator. 
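For reference, the channel-split semantics that the maxout kernels above implement for both layouts can be sketched in a few lines of NumPy. This is a minimal sketch mirroring the naive reference used in test_maxout_op.py; the helper name `maxout_ref` is illustrative, not part of the operator:

.. code-block:: python

    import numpy as np

    def maxout_ref(x, groups, axis=1):
        # x is a 4-D array in NCHW (axis=1) or NHWC (axis=3) format.
        # The channel dimension is split into (channels // groups, groups)
        # and reduced with a max over the groups dimension, so the output
        # has channels // groups channels.
        n, d1, d2, d3 = x.shape
        if axis == 3:  # NHWC: channels are the last dimension
            return x.reshape(n, d1, d2, d3 // groups, groups).max(axis=4)
        # NCHW: channels are the second dimension
        return x.reshape(n, d1 // groups, groups, d2, d3).max(axis=2)

Under this grouping, output channel c is the elementwise maximum over input channels [c * groups, (c + 1) * groups), which is the same indexing the CPU and CUDA kernels compute directly without the reshape.
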
@@ -70,17 +72,19 @@ class MaxOutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of MaxoutOpshould not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of MaxoutOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of MaxoutOpshould not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of MaxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); + int axis = ctx->Attrs().Get("axis"); // check groups > 1 - PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop"); - std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); - output_shape.push_back(in_x_dims[2]); - output_shape.push_back(in_x_dims[3]); + PADDLE_ENFORCE_GT(groups, 1, + "Attr(groups) of Op(maxout) should be larger than 1."); + std::vector output_shape( + {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]}); + output_shape[axis] = in_x_dims[axis] / groups; ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } }; diff --git a/paddle/fluid/operators/maxout_op.h b/paddle/fluid/operators/maxout_op.h index 5b9e003cb09..ec3897e4044 100644 --- a/paddle/fluid/operators/maxout_op.h +++ b/paddle/fluid/operators/maxout_op.h @@ -30,10 +30,11 @@ class MaxOutKernel : public framework::OpKernel { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); int groups = context.template Attr("groups"); + int axis = context.template Attr("axis"); math::MaxOutFunctor maxout_forward; maxout_forward(context.template device_context(), *in_x, out, - groups); + groups, axis); } }; @@ -47,13 +48,15 @@ class MaxOutGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); int groups = context.template Attr("groups"); + int axis = context.template Attr("axis"); auto& device_ctx = context.template device_context(); math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0.0)); math::MaxOutGradFunctor maxout_backward; - maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups); + maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups, + axis); } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5113e475289..a975ca5b609 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9334,7 +9334,8 @@ def lod_append(x, level): return out -def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): +def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None, + data_format='NCHW'): """ This operator implements the Local Response Normalization Layer. This layer performs a type of "lateral inhibition" by normalizing over local input regions. @@ -9355,13 +9356,18 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): Args: - input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError. + input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W] or [N, H, W, C], + where N is the batch size, C is the input channel, H is Height, W is weight. 
The data + type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError. n (int, optional): The number of channels to sum over. Default: 5 k (float, optional): An offset, positive. Default: 1.0 alpha (float, optional): The scaling parameter, positive. Default:1e-4 beta (float, optional): The exponent, positive. Default:0.75 - name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` - + name (str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name` + data_format(str, optional): The data format of the input and output data. An optional string + from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'. Returns: Variable: A tensor variable storing the transformation result with the same shape and data type as input. @@ -9384,8 +9390,12 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): if dims != 4: raise ValueError( - "dims of input must be 4(not %d), and it's order must be NCHW" % + "Input's dimension size of Op(lrn) must be 4, but received %d." % (dims)) + if data_format not in ['NCHW', 'NHWC']: + raise ValueError( + "Attr(data_format) of Op(lrn) got wrong value: received " + + data_format + " but only NCHW or NHWC supported.") mid_out = helper.create_variable_for_type_inference( dtype=dtype, stop_gradient=True) @@ -9397,10 +9407,13 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): "Out": lrn_out, "MidOut": mid_out, }, - attrs={"n": n, - "k": k, - "alpha": alpha, - "beta": beta}) + attrs={ + "n": n, + "k": k, + "alpha": alpha, + "beta": beta, + "data_format": data_format + }) return lrn_out @@ -11547,7 +11560,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): * Case 1 (input is a 2-D Tensor): Input: - X.shape = [3. 5] + X.shape = [3, 5] X.data = [[0, 1, 2, 0, 0], [0, 3, 4, 0, 0], [0, 0, 0, 0, 0]] @@ -11555,8 +11568,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None): shape = [2, 2] offsets = [0, 1] Output: - Out = [[1, 2], - [3, 4]] + Out.shape = [2, 2] + Out.data = [[1, 2], + [3, 4]] * Case 2 (input is a 3-D Tensor): Input: X.shape = [2, 3, 4] @@ -11567,24 +11581,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None): [0, 6, 7, 8], [0, 0, 0, 0]]] Parameters: - shape = [2, 2, 3] + shape = [2, 2, -1] offsets = [0, 0, 1] Output: - Out = [[[1, 2, 3], - [5, 6, 7]], - [[3, 4, 5], - [6, 7, 8]]] + Out.shape = [2, 2, 3] + Out.data = [[[1, 2, 3], + [5, 6, 7]], + [[3, 4, 5], + [6, 7, 8]]] Parameters: - x (Variable): 1-D to 6-D Tensor, the data type is float32 or float64. + x (Variable): 1-D to 6-D Tensor, the data type is float32, float64, int32 or int64. shape (list|tuple|Variable): The output shape is specified by `shape`. Its data type is int32. If a list/tuple, it's length must be the same as the dimension size of `x`. If a Variable, it shoule be a 1-D Tensor. When it is a list, each element can be an integer or a Tensor of shape: [1]. - If Variable contained, it is suitable for the case that the shape may - be changed each iteration. Only the first element of list/tuple can be - set to -1, it means that the first dimension's size of the output is the same - as the input. + If Variable contained, it is suitable for the case that the shape may + be changed each iteration. 
offsets (list|tuple|Variable, optional): Specifies the cropping offsets at each dimension. Its data type is int32. If a list/tuple, it's length must be the same as the dimension size of `x`. If a Variable, it shoule be a 1-D @@ -11598,8 +11611,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None): Variable: The cropped Tensor has same data type with `x`. Raises: - ValueError: If shape is not a list, tuple or Variable. - ValueError: If offsets is not None and not a list, tuple or Variable. + TypeError: If the data type of `x` is not in: float32, float64, int32, int64. + TypeError: If `shape` is not a list, tuple or Variable. + TypeError: If the data type of `shape` is not int32. + TypeError: If `offsets` is not None and not a list, tuple or Variable. + TypeError: If the data type of `offsets` is not int32. + ValueError: If the element in `offsets` is less than zero. Examples: @@ -11615,7 +11632,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): # crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime. # or shape is a list in which each element is a constant - crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3]) + crop1 = fluid.layers.crop_tensor(x, shape=[-1, -1, 3], offsets=[0, 1, 0]) # crop1.shape = [-1, 2, 3] # or shape is a list in which each element is a constant or Variable @@ -11637,70 +11654,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None): """ helper = LayerHelper('crop_tensor', **locals()) + if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']: + raise TypeError( + "Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, " + "but received %s." % (convert_dtype(x.dtype))) + if not (isinstance(shape, list) or isinstance(shape, tuple) or \ isinstance(shape, Variable)): - raise ValueError("The shape should be a list, tuple or Variable.") + raise TypeError( + "Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable." + ) if offsets is None: offsets = [0] * len(x.shape) if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \ isinstance(offsets, Variable)): - raise ValueError("The offsets should be a list, tuple or Variable.") + raise TypeError( + "Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable." + ) out = helper.create_variable_for_type_inference(x.dtype) ipts = {'X': x} attrs = {} - def contain_var(input_list): + def _contain_var(input_list): for ele in input_list: if isinstance(ele, Variable): return True return False + def _attr_shape_check(shape_val): + if not isinstance(shape_val, int): + raise TypeError( + "Attr(shape)'s dtype of Op(crop_tensor) should be int32, but received: %s." + % type(shape_val)) + if shape_val == 0: + raise ValueError( + "Attr(shape) of Op(crop_tensor) should not be zero, but received: %s." + % str(shape_val)) + if shape_val < -1: + raise ValueError( + "When the element in Attr(shape) of Op(crop_tensor) is negative, only -1 is supported, but received: %s." + % str(shape_val)) + + def _attr_offsets_check(offset_val): + if not isinstance(offset_val, int): + raise TypeError( + "Attr(offsets)'s dtype of Op(crop_tensor) should be int32, but received: %s." + % type(offset_val)) + if offset_val < 0: + raise ValueError( + "Attr(offsets) of Op(crop_tensor) should be greater or equal to zero, but received: %s." 
+ % str(offset_val)) + if isinstance(offsets, Variable): offsets.stop_gradient = True ipts['Offsets'] = offsets - elif contain_var(offsets): + attrs['offsets'] = [-1] * len(x.shape) + elif _contain_var(offsets): new_offsets_tensor = [] + offsets_attr = [] for dim in offsets: if isinstance(dim, Variable): dim.stop_gradient = True new_offsets_tensor.append(dim) + offsets_attr.append(-1) else: - assert (isinstance(dim, int)) - assert dim >= 0, ("offsets should be greater or equal to zero.") + _attr_offsets_check(dim) temp_out = helper.create_variable_for_type_inference('int32') fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) new_offsets_tensor.append(temp_out) + offsets_attr.append(dim) ipts['OffsetsTensor'] = new_offsets_tensor + attrs['offsets'] = offsets_attr else: + for offset in offsets: + _attr_offsets_check(offset) attrs['offsets'] = offsets - unk_dim_idx = -1 if isinstance(shape, Variable): shape.stop_gradient = True ipts['Shape'] = shape - elif contain_var(shape): + elif _contain_var(shape): new_shape_tensor = [] shape_attr = [] - for dim_idx, dim_size in enumerate(shape): + for dim_size in shape: if isinstance(dim_size, Variable): dim_size.stop_gradient = True new_shape_tensor.append(dim_size) - shape_attr.append(-1) + shape_attr.append(0) else: - assert (isinstance(dim_size, int)) - if dim_size == -1: - assert unk_dim_idx == -1, ( - "Only one element in shape can be unknown.") - assert dim_idx == 0, ( - "Only the first element in shape can be -1.") - unk_dim_idx = dim_idx - else: - assert dim_size > 0, ( - "Each dimension size given in shape must be greater than zero." - ) + _attr_shape_check(dim_size) temp_out = helper.create_variable_for_type_inference('int32') fill_constant( [1], 'int32', dim_size, force_cpu=True, out=temp_out) @@ -11709,6 +11754,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ipts['ShapeTensor'] = new_shape_tensor attrs['shape'] = shape_attr else: + for dim_size in shape: + _attr_shape_check(dim_size) attrs['shape'] = shape helper.append_op( @@ -15195,22 +15242,23 @@ def sigmoid_cross_entropy_with_logits(x, @templatedoc() -def maxout(x, groups, name=None): +def maxout(x, groups, name=None, axis=1): """ ${comment} Args: x(${x_type}): ${x_comment} - groups(${groups_type}): ${groups_comment} + groups(int): ${groups_comment} + axis(int, optional): ${axis_comment} name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. Returns: - Variable: - - out(${out_type}): ${out_comment} + Variable: ${out_comment} + Raises: + ValueError: If `axis` is not 1, -1 or 3. Examples: .. code-block:: python @@ -15223,6 +15271,12 @@ def maxout(x, groups, name=None): out = fluid.layers.maxout(input, groups=2) """ helper = LayerHelper("maxout", **locals()) + if axis not in [1, -1, 3]: + raise ValueError( + "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received " + "Attr(axis): %s." 
% str(axis)) + if axis == -1: + axis = 3 if name is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -15233,7 +15287,8 @@ def maxout(x, groups, name=None): helper.append_op( type="maxout", inputs={"X": x}, - attrs={"groups": groups}, + attrs={"groups": groups, + "axis": axis}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py index ed04a850042..5864e15df13 100644 --- a/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_crop_tensor_op.py @@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape): class TestCropTensorOp(OpTest): def setUp(self): self.op_type = "crop_tensor" - self.crop_by_1D_shape = False + self.shape_by_input = False self.offset_by_input = False self.unk_dim_idx = -1 self.attrs = {} self.initTestCase() - if self.crop_by_1D_shape: + if self.shape_by_input: self.inputs = { 'X': np.random.random(self.x_shape).astype("float32"), 'Shape': np.array(self.crop_shape).astype("int32") @@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest): else: self.attrs['offsets'] = self.offsets - if self.unk_dim_idx != -1: - self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx] - self.outputs = { - 'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) - } + crop_shape = [val for val in self.crop_shape] + for i in range(len(self.crop_shape)): + if self.crop_shape[i] == -1: + crop_shape[i] = self.x_shape[i] - self.offsets[i] + self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)} def initTestCase(self): self.x_shape = (8, 8) @@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp): class TestCase2(TestCropTensorOp): def initTestCase(self): self.x_shape = (12, 24) - self.crop_shape = [-1, 8] #only the first dimension (batch) can be -1 + self.crop_shape = [-1, 8] self.offsets = [0, 0] - self.unk_dim_idx = 0 class TestCase3(TestCropTensorOp): @@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp): self.x_shape = (4, 8, 16) self.crop_shape = [2, 2, 3] self.offsets = [1, 5, 3] - self.crop_by_1D_shape = True + self.shape_by_input = True class TestCase4(TestCropTensorOp): def initTestCase(self): self.x_shape = (8, 3, 6, 6) - self.crop_shape = [-1, 3, 4, 4] - self.offsets = [0, 0, 0, 0] - self.crop_by_1D_shape = True - self.unk_dim_idx = 0 + self.crop_shape = [-1, 3, -1, 4] + self.offsets = [0, 0, 1, 0] + self.shape_by_input = True class TestCase5(TestCropTensorOp): @@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp): self.x_shape = (2, 2, 4, 4, 4, 2) self.crop_shape = [1, 1, 4, 2, 2, 2] self.offsets = [0, 0, 0, 0, 0, 0] - self.crop_by_1D_shape = True + self.shape_by_input = True self.offset_by_input = True -class TestCropTensorOp_attr_tensor(OpTest): +class TestCropTensorOpTensorAttr(OpTest): def setUp(self): self.op_type = "crop_tensor" - self.mixed_type = False self.OffsetsTensor = False self.ShapeTensor = True self.attrs = {} @@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest): 'X': np.random.random(self.x_shape).astype("float32"), 'ShapeTensor': shape_tensor } - if self.mixed_type: - self.attrs['shape'] = self.shape_attr + self.attrs['shape'] = self.shape_attr if self.OffsetsTensor: offsets_tensor = [] @@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest): 'X': np.random.random(self.x_shape).astype("float32"), 'OffsetsTensor': offsets_tensor } - else: - self.attrs['offsets'] = self.offsets + self.attrs['offsets'] = self.offsets_attr - self.outputs = { - 
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) - } + self.attrs['shape'] = self.crop_shape + self.attrs['offsets'] = self.offsets + crop_shape = [val for val in self.crop_shape] + for i in range(len(self.crop_shape)): + if self.crop_shape[i] == -1: + crop_shape[i] = self.x_shape[i] - self.offsets[i] + self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)} def initTestCase(self): self.x_shape = (8, 8) self.crop_shape = (2, 2) self.offsets = [1, 2] + self.shape_attr = [0, 0] def test_check_output(self): self.check_output() @@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest): self.check_grad(["X"], "Out", max_relative_error=0.006) -class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor): - def init_data(self): +class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr): + def initTestCase(self): self.x_shape = (16, 8, 32) - self.crop_shape = [2, 2, 3] + self.crop_shape = [-1, -1, 3] self.offsets = [1, 5, 3] + self.shape_attr = [-1, -1, 3] -class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor): - def init_data(self): +class TestCropTensorOpTensorAttrCase2(TestCropTensorOpTensorAttr): + def initTestCase(self): self.x_shape = (4, 8, 16, 8) self.crop_shape = [2, 2, 3, 4] self.offsets = [1, 5, 3, 0] - self.shape_attr = [-1, -1, 3, 4] - self.mixed_type = True + self.shape_attr = [0, 0, 3, 4] -class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor): - def init_data(self): +class TestCropTensorOpTensorAttrCase3(TestCropTensorOpTensorAttr): + def initTestCase(self): self.x_shape = (16, 8, 32) self.crop_shape = [2, 2, 3] self.offsets = [1, 5, 3] + self.offsets_attr = [-1, -1, 3] self.ShapeTensor = False self.OffsetsTensor = True -class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor): - def init_data(self): +class TestCropTensorOpTensorAttrCase4(TestCropTensorOpTensorAttr): + def initTestCase(self): self.x_shape = (16, 8, 32) self.crop_shape = [2, 2, 3] + self.shape_attr = [0, 2, 3] self.offsets = [1, 5, 3] + self.offsets_attr = [-1, -1, 3] self.OffsetsTensor = True +class TestCropTensorException(OpTest): + def test_exception(self): + input1 = fluid.data(name="input1", shape=[2, 3, 6, 6], dtype="float32") + input2 = fluid.data(name="input2", shape=[2, 3, 6, 6], dtype="float16") + dim = fluid.data(name='dim', shape=[1], dtype='int32') + offset = fluid.data(name='offset', shape=[1], dtype='int32') + + def attr_shape_type(): + out = fluid.layers.crop_tensor(input1, shape=3) + + def attr_shape_dtype(): + out = fluid.layers.crop_tensor(input1, shape=[2, 2.0, 3, 3]) + + def attr_shape_value1(): + out = fluid.layers.crop_tensor(input1, shape=[2, -2, dim, 3]) + + def attr_shape_value2(): + out = fluid.layers.crop_tensor(input1, shape=[2, 0, dim, 3]) + + def attr_offsets_type(): + out = fluid.layers.crop_tensor( + input1, shape=[2, 2, 3, 3], offsets=0) + + def attr_offsets_dtype(): + out = fluid.layers.crop_tensor( + input1, shape=[2, 2, 3, 3], offsets=[0, 1.0, 0, 0]) + + def attr_offsets_value(): + out = fluid.layers.crop_tensor( + input1, shape=[2, 2, 3, 3], offsets=[0, -1, offset, 0]) + + def input_dtype(): + out = fluid.layers.crop_tensor(input2, shape=[2, 2, 3, 3]) + + self.assertRaises(TypeError, attr_shape_type) + self.assertRaises(TypeError, attr_shape_dtype) + self.assertRaises(ValueError, attr_shape_value1) + self.assertRaises(ValueError, attr_shape_value2) + self.assertRaises(TypeError, attr_offsets_type) + self.assertRaises(TypeError, attr_offsets_dtype) + self.assertRaises(ValueError, 
attr_offsets_value) + self.assertRaises(TypeError, input_dtype) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index bb91f26bbb5..7ec329d9f6c 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core from op_test import OpTest @@ -60,12 +62,15 @@ class TestLRNOp(OpTest): 'n': self.n, 'k': self.k, 'alpha': self.alpha, - 'beta': self.beta + 'beta': self.beta, + 'data_format': self.data_format } return attrs def setUp(self): self.op_type = "lrn" + self.init_test_case() + self.N = 2 self.C = 3 self.H = 5 @@ -77,11 +82,18 @@ class TestLRNOp(OpTest): self.beta = 0.75 self.x = self.get_input() self.out, self.mid_out = self.get_out() + if self.data_format == 'NHWC': + self.x = np.transpose(self.x, [0, 2, 3, 1]) + self.out = np.transpose(self.out, [0, 2, 3, 1]) + self.mid_out = np.transpose(self.mid_out, [0, 2, 3, 1]) self.inputs = {'X': self.x} self.outputs = {'Out': self.out, 'MidOut': self.mid_out} self.attrs = self.get_attrs() + def init_test_case(self): + self.data_format = 'NCHW' + def test_check_output(self): self.check_output() @@ -89,5 +101,49 @@ class TestLRNOp(OpTest): self.check_grad(['X'], 'Out', max_relative_error=0.01) +class TestLRNOpAttrDataFormat(TestLRNOp): + def init_test_case(self): + self.data_format = 'NHWC' + + +class TestLRNAPI(OpTest): + def test_case(self): + data1 = fluid.data(name='data1', shape=[2, 4, 5, 5], dtype='float32') + data2 = fluid.data(name='data2', shape=[2, 5, 5, 4], dtype='float32') + out1 = fluid.layers.lrn(data1, data_format='NCHW') + out2 = fluid.layers.lrn(data2, data_format='NHWC') + data1_np = np.random.random((2, 4, 5, 5)).astype("float32") + data2_np = np.transpose(data1_np, [0, 2, 3, 1]) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2], + return_numpy=True) + + self.assertTrue( + np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2)))) + + def test_exception(self): + input1 = fluid.data(name="input1", shape=[2, 4, 5, 5], dtype="float32") + input2 = fluid.data( + name="input2", shape=[2, 4, 5, 5, 5], dtype="float32") + + def _attr_data_fromat(): + out = fluid.layers.lrn(input1, data_format='NDHW') + + def _input_dim_size(): + out = fluid.layers.lrn(input2) + + self.assertRaises(ValueError, _attr_data_fromat) + self.assertRaises(ValueError, _input_dim_size) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_maxout_op.py b/python/paddle/fluid/tests/unittests/test_maxout_op.py index d588b22fe26..19c517142fd 100644 --- a/python/paddle/fluid/tests/unittests/test_maxout_op.py +++ b/python/paddle/fluid/tests/unittests/test_maxout_op.py @@ -16,11 +16,16 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core from op_test import OpTest -def maxout_forward_naive(input, groups): +def maxout_forward_naive(input, groups, channel_axis): s0, s1, s2, s3 = input.shape + if channel_axis == 3: + return np.ndarray([s0, s1, s2, s3 // groups, groups], \ + 
+            buffer = input, dtype=input.dtype).max(axis=(4))
     return np.ndarray([s0, s1 // groups, groups, s2, s3], \
         buffer = input, dtype=input.dtype).max(axis=(2))
 
@@ -30,10 +35,11 @@ class TestMaxOutOp(OpTest):
         self.op_type = "maxout"
         self.init_test_case()
         input = np.random.random(self.shape).astype("float32")
-        output = self.MaxOut_forward_naive(input, self.groups).astype("float32")
+        output = self.MaxOut_forward_naive(input, self.groups,
+                                           self.axis).astype("float32")
 
         self.inputs = {'X': input}
-        self.attrs = {'groups': self.groups}
+        self.attrs = {'groups': self.groups, 'axis': self.axis}
 
         self.outputs = {'Out': output.astype('float32')}
@@ -47,6 +53,48 @@ class TestMaxOutOp(OpTest):
         self.MaxOut_forward_naive = maxout_forward_naive
         self.shape = [100, 6, 2, 2]
         self.groups = 2
+        self.axis = 1
+
+
+class TestMaxOutOpAxis(TestMaxOutOp):
+    def init_test_case(self):
+        self.MaxOut_forward_naive = maxout_forward_naive
+        self.shape = [100, 2, 2, 6]  # NHWC format
+        self.groups = 2
+        self.axis = 3
+
+
+class TestMaxOutOpAxisAPI(OpTest):
+    def test_axis(self):
+        data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
+        data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
+        out1 = fluid.layers.maxout(data1, groups=2, axis=1)
+        out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
+        data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
+        data2_np = np.transpose(data1_np, [0, 2, 3, 1])
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        results = exe.run(fluid.default_main_program(),
+                          feed={"data1": data1_np,
+                                "data2": data2_np},
+                          fetch_list=[out1, out2],
+                          return_numpy=True)
+
+        self.assertTrue(
+            np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
+
+    def test_exception(self):
+        input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")
+
+        def _attr_axis():
+            out = fluid.layers.maxout(input, groups=2, axis=2)
+
+        self.assertRaises(ValueError, _attr_axis)
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
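
For reference, the Attr(shape)/Attr(offsets) semantics this patch introduces for crop_tensor — any element of Attr(shape) may now be -1, meaning "crop from the offset to the end of that dimension" — can be summarized with a small NumPy sketch. The helper name `crop_ref` is illustrative, not part of the API:

.. code-block:: python

    import numpy as np

    def crop_ref(x, shape, offsets):
        # A -1 in `shape` selects everything from the offset to the end of
        # that dimension, i.e. an output size of x.shape[i] - offsets[i].
        slices = []
        for i, s in enumerate(shape):
            size = x.shape[i] - offsets[i] if s == -1 else s
            slices.append(slice(offsets[i], offsets[i] + size))
        return x[tuple(slices)]

    x = np.arange(15).reshape(3, 5)
    out = crop_ref(x, shape=[2, -1], offsets=[1, 2])  # out.shape == (2, 3)
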