Unverified commit 3848f720, authored by Zhang Ting, committed by GitHub

[cherry-pick] fix crop_tensor, maxout and lrn (#21302)

* [cherry-pick] All elements in attr(shape) of crop_tensor can be -1 and int32/64 kernel registered (#20756)

* All elements in attr(shape) of crop_tensor can be -1, test=develop, test=document_preview

* fix the bug that attr(offsets) should be initialized, test=develop

* [cherry-pick] maxout supports channel_last input (#20846)

* maxout support channel_last input, test=develop

* modified details of Input(X) and Attr(groups, axis) in doc, test=develop

* [cherry-pick] lrn supports channel_last input, test=develop (#20954)
Parent 9f004548
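The diff below touches the C++/CUDA kernels and the Python wrappers. As a quick orientation, here is a minimal usage sketch of the updated fluid APIs (assuming PaddlePaddle 1.6 with this cherry-pick applied; the sketch itself is not part of the commit):

import paddle.fluid as fluid

# crop_tensor: any element of Attr(shape) may now be -1, meaning
# x_dim[i] - offsets[i], and int32/int64 kernels are registered.
x = fluid.data(name="x", shape=[3, 5], dtype="int64")
cropped = fluid.layers.crop_tensor(x, shape=[-1, -1], offsets=[0, 1])  # shape [3, 4]

# lrn: channel-last input via the new data_format attribute.
img_nhwc = fluid.data(name="img", shape=[-1, 32, 32, 8], dtype="float32")
norm = fluid.layers.lrn(img_nhwc, n=5, data_format="NHWC")

# maxout: the companion change (#20846) adds an axis argument for NHWC input;
# the Python wrapper is not shown in this excerpt, so treat this call as an assumption.
pooled = fluid.layers.maxout(img_nhwc, groups=2, axis=-1)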
@@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel {
                      "Input(X) of Op(crop_tensor) should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of Op(crop_tensor) should not be null.");
+     auto x_dim = ctx->GetInputDim("X");
    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
+     auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
    if (ctx->HasInputs("ShapeTensor")) {
      // top prority shape
      auto inputs_name = ctx->Inputs("ShapeTensor");
@@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel {
          "Op(fluid.layers.crop_tensor).");
      auto out_dims = std::vector<int>(inputs_name.size(), -1);
      for (size_t i = 0; i < shape.size(); ++i) {
-         if (shape[i] != -1) {
+         if (shape[i] > 0) {
          out_dims[i] = static_cast<int64_t>(shape[i]);
+         } else {
+           if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
+             out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
+           }
        }
      }
      ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
      return;
    }
-     auto x_dim = ctx->GetInputDim("X");
    if (ctx->HasInput("Shape")) {
      auto shape_dim = ctx->GetInputDim("Shape");
      PADDLE_ENFORCE_EQ(
@@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel {
    PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
                      "Attr(shape)'size of Op(crop_tensor) should be equal to "
                      "dimention size of input tensor.");
-     std::vector<int64_t> tensor_shape(shape.size());
+     std::vector<int64_t> out_shape(shape.size(), -1);
    for (size_t i = 0; i < shape.size(); ++i) {
-       tensor_shape[i] = static_cast<int64_t>(shape[i]);
+       if (shape[i] > 0) {
+         out_shape[i] = static_cast<int64_t>(shape[i]);
+       } else {
+         if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
+           out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
+         }
+       }
    }
-     ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
+     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
  }
  framework::OpKernelType GetExpectedKernelType(
@@ -294,8 +305,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
REGISTER_OP_CPU_KERNEL(
    crop_tensor,
    ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
-     ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>);
+     ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>,
+     ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int>,
+     ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    crop_tensor_grad,
    ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
-     ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>);
+     ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>,
+     ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int>,
+     ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -17,8 +17,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    crop_tensor,
    ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
-     ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>);
+     ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>,
+     ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int>,
+     ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    crop_tensor_grad,
    ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
-     ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>);
+     ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>,
+     ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int>,
+     ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -50,29 +50,28 @@ inline std::vector<int> get_new_data(
}
static framework::DDim ValidateShape(const std::vector<int> shape,
+                                      const std::vector<int> offsets,
                                     const framework::DDim& in_dims) {
  auto in_dim_size = in_dims.size();
  auto shape_size = shape.size();
  PADDLE_ENFORCE_EQ(
      in_dim_size, shape_size,
-       "Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal "
-       "to that of input tensor. "
+       "Attr(shape)'s size of Op(crop_tensor) should be equal "
+       "to that of input Tensor. "
      "Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
-   const int64_t unk_dim_val = -1;
-   int unk_dim_idx = -1;
  std::vector<int64_t> output_shape(shape.size(), 0);
  for (size_t i = 0; i < shape.size(); ++i) {
-     if (shape[i] == unk_dim_val) {
-       PADDLE_ENFORCE_EQ(unk_dim_idx, -1,
-                         "Only one element of shape can be unknown.");
-       PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1.");
-       unk_dim_idx = i;
+     if (shape[i] <= 0 && in_dims[i] > 0) {
+       PADDLE_ENFORCE_NE(
+           shape[i], 0,
+           "The element in Attr(shape) of Op(crop_tensor) should not be zero.");
+       PADDLE_ENFORCE_EQ(shape[i], -1,
+                         "When the element in Attr(shape) of Op(crop_tensor) is "
+                         "negative, only -1 is supported.");
+       output_shape[i] = in_dims[i] - offsets[i];
    } else {
-       PADDLE_ENFORCE_GT(shape[i], 0,
-                         "Each element of shape must be greater than 0 "
-                         "except the first element.");
+       output_shape[i] = static_cast<int64_t>(shape[i]);
    }
-     output_shape[i] = static_cast<int64_t>(shape[i]);
  }
  return framework::make_ddim(output_shape);
@@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
      shape.push_back(out_dims[i]);
    }
  }
-   out_dims = ValidateShape(shape, x->dims());
-   if (out_dims[0] == -1) {
-     out_dims[0] = x->dims()[0];
-   }
-   out->mutable_data<T>(out_dims, context.GetPlace());
-   auto x_stride = framework::stride(x->dims());
  auto offsets = GetOffsets(context);
-   int64_t offset = 0;
+   out_dims = ValidateShape(shape, offsets, x->dims());
+   out->mutable_data<T>(out_dims, context.GetPlace());
  for (size_t i = 0; i < offsets.size(); ++i) {
    PADDLE_ENFORCE_LE(
        offsets[i] + shape[i], x_dims[i],
        "The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
        "should be less than or equal to corresponding input dimension size.");
-     offset += (x_stride[i] * offsets[i]);
  }
  auto x_tensor = EigenTensor<T, D>::From(*x);
...
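The reworked ValidateShape above boils down to a simple per-dimension rule; here is a hedged Python mirror (an illustrative helper, not part of the commit):

def infer_crop_shape(x_dims, shape_attr, offsets):
    # Positive entries are taken as-is; -1 means "the rest of the dimension
    # after the offset", i.e. x_dims[i] - offsets[i]; 0 is rejected.
    out = []
    for i, s in enumerate(shape_attr):
        if s > 0:
            out.append(s)
        elif s == -1:
            out.append(x_dims[i] - offsets[i])
        else:
            raise ValueError("elements of shape must be positive or -1")
    return out

# Case 2 from the crop_tensor docstring further below:
# X.shape = [2, 3, 4], shape = [2, 2, -1], offsets = [0, 0, 1] -> [2, 2, 3]
assert infer_crop_shape([2, 3, 4], [2, 2, -1], [0, 0, 1]) == [2, 2, 3]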
@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/lrn_op.h"
#include <string>
+ #include <vector>
#include "paddle/fluid/operators/math/blas.h"
+ #include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
@@ -23,18 +25,41 @@ namespace paddle {
namespace operators {
using framework::Tensor;
+ using DataLayout = framework::DataLayout;
template <typename T>
struct LRNFunctor<platform::CPUDeviceContext, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
-                   T k, T alpha, T beta) {
-     const T* idata = input.data<T>();
+                   T k, T alpha, T beta, const DataLayout data_layout) {
    auto place = ctx.GetPlace();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
-     T* odata = out->mutable_data<T>(place);
-     T* mdata = mid->mutable_data<T>(place);
+     math::Transpose<platform::CPUDeviceContext, T, 4> transpose;
+     auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+     Tensor in_transpose, mid_transpose, out_transpose;
+     // if channel_last, transpose to channel_first
+     if (data_layout == DataLayout::kNHWC) {
+       auto in_dims = input.dims();
+       std::vector<int64_t> shape(
+           {in_dims[0], in_dims[3], in_dims[1], in_dims[2]});
+       in_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+       mid_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+       out_transpose.mutable_data<T>(framework::make_ddim(shape), place);
+       std::vector<int> axis = {0, 3, 1, 2};
+       transpose(dev_ctx, input, &in_transpose, axis);
+     } else {
+       in_transpose = input;
+       mid_transpose = *mid;
+       out_transpose = *out;
+       mid_transpose.mutable_data<T>(mid->dims(), place);
+       out_transpose.mutable_data<T>(out->dims(), place);
+     }
+     const T* idata = in_transpose.data<T>();
+     T* odata = out_transpose.data<T>();
+     T* mdata = mid_transpose.data<T>();
    Tensor squared;
    T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
    std::memset(sdata, 0, sizeof(T) * squared.numel());
@@ -67,6 +92,13 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
    // compute the final output
    blas.VPOW(mid->numel(), mdata, -beta, odata);
    blas.VMUL(mid->numel(), odata, idata, odata);
+     // if channel_last, transpose the output(NCHW) to channel_last
+     if (data_layout == DataLayout::kNHWC) {
+       std::vector<int> axis = {0, 2, 3, 1};
+       transpose(dev_ctx, mid_transpose, mid, axis);
+       transpose(dev_ctx, out_transpose, out, axis);
+     }
  }
};
template struct LRNFunctor<platform::CPUDeviceContext, float>;
@@ -78,7 +110,7 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
                  const framework::Tensor& x, const framework::Tensor& out,
                  const framework::Tensor& mid, framework::Tensor* x_g,
                  const framework::Tensor& out_g, int N, int C, int H, int W,
-                   int n, T alpha, T beta) {
+                   int n, T alpha, T beta, const DataLayout data_layout) {
    T ratio = -2 * alpha * beta;
    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
    x_g_e = x_g_e.constant(0.0);
@@ -93,17 +125,17 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
    const int end = start + n;
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
-         auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                              Eigen::array<int, 4>({{1, 1, H, W}}));
-         auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                  Eigen::array<int, 4>({{1, 1, H, W}}));
-         auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                      Eigen::array<int, 4>({{1, 1, H, W}}));
-         auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                  Eigen::array<int, 4>({{1, 1, H, W}}));
+         auto offsets = Eigen::array<int, 4>({{m, i, 0, 0}});
+         auto extents = Eigen::array<int, 4>({{1, 1, H, W}});
+         if (data_layout == DataLayout::kNHWC) {
+           offsets = Eigen::array<int, 4>({{m, 0, 0, i}});
+           extents = Eigen::array<int, 4>({{1, H, W, 1}});
+         }
+         auto i_x = e_x.slice(offsets, extents);
+         auto i_x_g = e_x_g.slice(offsets, extents);
+         auto i_out_g = e_out_g.slice(offsets, extents);
+         auto i_mid = e_mid.slice(offsets, extents);
        i_x_g = i_mid.pow(-beta) * i_out_g;
        for (int c = start; c < end; c++) {
@@ -112,14 +144,14 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
            continue;
          }
-           auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                    Eigen::array<int, 4>({{1, 1, H, W}}));
-           auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                    Eigen::array<int, 4>({{1, 1, H, W}}));
-           auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                        Eigen::array<int, 4>({{1, 1, H, W}}));
+           if (data_layout != DataLayout::kNHWC) {
+             offsets = Eigen::array<int, 4>({{m, ch, 0, 0}});
+           } else {
+             offsets = Eigen::array<int, 4>({{m, 0, 0, ch}});
+           }
+           auto c_out = e_out.slice(offsets, extents);
+           auto c_mid = e_mid.slice(offsets, extents);
+           auto c_out_g = e_out_g.slice(offsets, extents);
          i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
        }
@@ -156,9 +188,8 @@ class LRNOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
-     std::string data_format = ctx.Attr<std::string>("data_format");
    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
    if (library_ == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
@@ -242,8 +273,8 @@ $$
Function implementation:
- Inputs and outpus are in NCHW format, while input.shape.ndims() equals 4.
- And dimensions 0 ~ 3 represent batch size, feature maps, rows,
+ Inputs and outpus are in NCHW or NHWC format, while input.shape.ndims() equals 4.
+ If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows,
and columns, respectively.
Input and Output in the formula above is for each map(i) of one image, and
@@ -275,9 +306,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
-     std::string data_format = ctx.Attr<std::string>("data_format");
    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
    if (library_ == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
...
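The CPU functor above handles NHWC by transposing to NCHW, normalizing across channels, and transposing back. A NumPy reference of that strategy follows (a sketch only, assuming float input and odd n; the function name is made up):

import numpy as np

def lrn_ref(x, n=5, k=1.0, alpha=1e-4, beta=0.75, data_format="NCHW"):
    if data_format == "NHWC":
        x = np.transpose(x, (0, 3, 1, 2))      # NHWC -> NCHW, like the functor
    N, C, H, W = x.shape
    mid = np.full_like(x, k)                   # MidOut = k + alpha * sum of squares
    half = (n - 1) // 2
    for c in range(C):
        lo, hi = max(0, c - half), min(C, c + half + 1)
        mid[:, c] += alpha * np.sum(x[:, lo:hi] ** 2, axis=1)
    out = x * mid ** (-beta)                   # Out = X * MidOut^(-beta)
    if data_format == "NHWC":
        out = np.transpose(out, (0, 2, 3, 1))  # back to channel_last
    return out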
@@ -17,15 +17,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
+ using DataLayout = framework::DataLayout;
template <typename T>
__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
-                                    int H, int W, int size, T k, T alpha) {
+                                    int H, int W, int size, T k, T alpha,
+                                    const DataLayout data_layout) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < img_size) {
    const int w = idx % W;
    const int h = (idx / W) % H;
    const int n = idx / W / H;
-     const int offset = (n * C * H + h) * W + w;
+     const int offset =
+         (data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
+                                           : ((n * H + h) * W + w) * C);
    in += offset;
    mid += offset;
@@ -37,15 +42,21 @@ __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
    int index = 0;
    while (index < C + post_pad) {
      if (index < C) {
-         T val = in[index * step];
+         int in_idx = (data_layout != DataLayout::kNHWC ? index * step : index);
+         T val = in[in_idx];
        accum += val * val;
      }
      if (index >= size) {
-         T val = in[(index - size) * step];
+         int in_idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
+                                                        : index - size);
+         T val = in[in_idx];
        accum -= val * val;
      }
      if (index >= post_pad) {
-         mid[(index - post_pad) * step] = k + accum * alpha;
+         int mid_idx =
+             (data_layout != DataLayout::kNHWC ? (index - post_pad) * step
+                                               : index - post_pad);
+         mid[mid_idx] = k + accum * alpha;
      }
      ++index;
    }
@@ -64,14 +75,14 @@ __global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
template <typename T>
void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
                    T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
-                     T alpha, T beta) {
+                     T alpha, T beta, const DataLayout data_layout) {
  int img_size = N * H * W;
  const int block_size = 1024;
  int grid_size = (img_size + block_size - 1) / block_size;
  auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  KeCMRNormFillScale<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-       img_size, inputs, mid, C, H, W, n, k, alpha);
+       img_size, inputs, mid, C, H, W, n, k, alpha, data_layout);
  int input_size = N * H * W * C;
  grid_size = (input_size + block_size - 1) / block_size;
@@ -84,10 +95,11 @@ struct LRNFunctor<platform::CUDADeviceContext, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
-                   T k, T alpha, T beta) {
-     CrossMapNormal<T>(
-         ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()),
-         mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta);
+                   T k, T alpha, T beta, const DataLayout data_layout) {
+     CrossMapNormal<T>(ctx, input.data<T>(),
+                       out->mutable_data<T>(ctx.GetPlace()),
+                       mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k,
+                       alpha, beta, data_layout);
  }
};
@@ -97,14 +109,16 @@ template struct LRNFunctor<platform::CUDADeviceContext, double>;
template <typename T>
__global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
                              const T* mid, T* x_g, const T* out_g, int C,
-                               int H, int W, int size, T negative_beta,
-                               T ratio) {
+                               int H, int W, int size, T negative_beta, T ratio,
+                               const DataLayout data_layout) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < img_size) {
    const int w = idx % W;
    const int h = (idx / W) % H;
    const int n = idx / W / H;
-     const int offset = (n * C * H + h) * W + w;
+     const int offset =
+         (data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
+                                           : ((n * H + h) * W + w) * C);
    x += offset;
    out += offset;
    mid += offset;
@@ -120,18 +134,20 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
    // TODO(gongwb): optimize this with thread shared array.
    while (index < C + post_pad) {
      if (index < C) {
-         x_g[index * step] = 0.0;
-         accum += out_g[index * step] * out[index * step] / mid[index * step];
+         int idx = (data_layout != DataLayout::kNHWC ? index * step : index);
+         x_g[idx] = 0.0;
+         accum += out_g[idx] * out[idx] / mid[idx];
      }
      if (index >= size) {
-         accum -= out_g[(index - size) * step] * out[(index - size) * step] /
-                  mid[(index - size) * step];
+         int idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
+                                                     : index - size);
+         accum -= out_g[idx] * out[idx] / mid[idx];
      }
      if (index >= post_pad) {
-         x_g[(index - post_pad) * step] +=
-             out_g[(index - post_pad) * step] *
-                 pow(mid[(index - post_pad) * step], negative_beta) -
-             ratio * x[(index - post_pad) * step] * accum;
+         int idx = (data_layout != DataLayout::kNHWC ? (index - post_pad) * step
+                                                     : index - post_pad);
+         x_g[idx] +=
+             out_g[idx] * pow(mid[idx], negative_beta) - ratio * x[idx] * accum;
      }
      ++index;
    }
@@ -141,7 +157,8 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
template <typename T>
void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
                        const T* out, const T* mid, T* x_g, const T* out_g,
-                         int N, int C, int H, int W, int n, T alpha, T beta) {
+                         int N, int C, int H, int W, int n, T alpha, T beta,
+                         const DataLayout data_layout) {
  int img_size = N * H * W;
  const int block_size = 1024;
@@ -149,8 +166,8 @@ void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
  auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  KeCMRNormDiff<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-       img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta,
-       2.0f * alpha * beta);
+       img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, 2.0f * alpha * beta,
+       data_layout);
}
template <typename T>
@@ -159,10 +176,10 @@ struct LRNGradFunctor<platform::CUDADeviceContext, T> {
                  const framework::Tensor& x, const framework::Tensor& out,
                  const framework::Tensor& mid, framework::Tensor* x_g,
                  const framework::Tensor& out_g, int N, int C, int H, int W,
-                   int n, T alpha, T beta) {
+                   int n, T alpha, T beta, const DataLayout data_layout) {
    CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(),
                          x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(),
-                           N, C, H, W, n, alpha, beta);
+                           N, C, H, W, n, alpha, beta, data_layout);
  }
};
...
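The two offset expressions introduced above are simply the flat indices of element (n, c=0, h, w) in the two layouts, with a per-channel stride ("step") of H*W for NCHW and 1 for NHWC. A quick illustrative check (values arbitrary, NumPy assumed):

import numpy as np

N, C, H, W = 2, 8, 5, 7
n, h, w = 1, 3, 4
assert (n * C * H + h) * W + w == np.ravel_multi_index((n, 0, h, w), (N, C, H, W))
assert ((n * H + h) * W + w) * C == np.ravel_multi_index((n, h, w, 0), (N, H, W, C))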
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
+ #include <string>
+ #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -21,12 +23,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
+ using DataLayout = framework::DataLayout;
template <typename place, typename T>
struct LRNFunctor {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
-                   T k, T alpha, T beta);
+                   T k, T alpha, T beta,
+                   const DataLayout data_layout = DataLayout::kAnyLayout);
};
template <typename DeviceContext, typename T>
@@ -42,11 +47,14 @@ class LRNKernel : public framework::OpKernel<T> {
    const Tensor& x = *ctx.Input<Tensor>("X");
    auto x_dims = x.dims();
+     const std::string data_layout_str = ctx.Attr<std::string>("data_format");
+     const framework::DataLayout data_layout =
+         framework::StringToDataLayout(data_layout_str);
    // NCHW
    int N = x_dims[0];
-     int C = x_dims[1];
-     int H = x_dims[2];
-     int W = x_dims[3];
+     int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
+     int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
+     int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);
    Tensor* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
@@ -65,7 +73,7 @@ class LRNKernel : public framework::OpKernel<T> {
    PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
    LRNFunctor<DeviceContext, T> f;
-     f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta);
+     f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta, data_layout);
  }
};
@@ -75,7 +83,8 @@ struct LRNGradFunctor {
                  const framework::Tensor& x, const framework::Tensor& out,
                  const framework::Tensor& mid, framework::Tensor* x_g,
                  const framework::Tensor& out_g, int N, int C, int H, int W,
-                   int n, T alpha, T beta);
+                   int n, T alpha, T beta,
+                   const DataLayout data_layout = DataLayout::kAnyLayout);
};
/**
@@ -106,15 +115,18 @@ class LRNGradKernel : public framework::OpKernel<T> {
    const Tensor& out = *ctx.Input<Tensor>("Out");
    const Tensor& out_g = *ctx.Input<Tensor>(framework::GradVarName("Out"));
    const Tensor& mid = *ctx.Input<Tensor>("MidOut");
+     const std::string data_layout_str = ctx.Attr<std::string>("data_format");
+     const framework::DataLayout data_layout =
+         framework::StringToDataLayout(data_layout_str);
    auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
    x_g->mutable_data<T>(ctx.GetPlace());
    auto x_dims = x.dims();
    int N = x_dims[0];
-     int C = x_dims[1];
-     int H = x_dims[2];
-     int W = x_dims[3];
+     int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
+     int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
+     int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);
    int n = ctx.Attr<int>("n");
    T alpha = ctx.Attr<T>("alpha");
@@ -125,7 +137,7 @@ class LRNGradKernel : public framework::OpKernel<T> {
        "is_test attribute should be set to False in training phase.");
    LRNGradFunctor<DeviceContext, T> f;
-     f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
+     f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta, data_layout);
  }
};
...
@@ -18,35 +18,45 @@ namespace paddle {
namespace operators {
namespace math {
- // All tensors are in NCHW format, and the groups must be greater than 1
+ // All tensors are in NCHW or NHWC format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
-                   int groups) {
+                   const int groups, const int axis) {
    const int batch_size = input.dims()[0];
-     const int input_height = input.dims()[2];
-     const int input_width = input.dims()[3];
-     const int output_channels = output->dims()[1];
+     const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+     const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+     const int output_channels = output->dims()[axis];
    int fea_size = input_height * input_width;
    // c_size means the output size of each sample
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          T ele = static_cast<T>(-FLT_MAX);
+           int input_idx, output_idx;
          for (int ph = 0; ph < groups; ++ph) {
-             T x = input_data[(new_bindex + new_cindex) * groups +
-                              ph * fea_size + f];
+             if (axis == 1) {
+               input_idx =
+                   (new_bindex + new_cindex) * groups + ph * fea_size + f;
+             } else {
+               input_idx = (new_bindex + f * output_channels + c) * groups + ph;
+             }
+             T x = input_data[input_idx];
            ele = ele > x ? ele : x;
          }
-           output_data[(new_bindex + new_cindex + f)] = ele;
+           if (axis == 1) {
+             output_idx = new_bindex + new_cindex + f;
+           } else {
+             output_idx = new_bindex + f * output_channels + c;
+           }
+           output_data[output_idx] = ele;
        }
      }
    }
@@ -59,11 +69,12 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
-                   const framework::Tensor& output_grad, int groups) {
+                   const framework::Tensor& output_grad, const int groups,
+                   const int axis) {
    const int batch_size = input.dims()[0];
-     const int input_height = input.dims()[2];
-     const int input_width = input.dims()[3];
-     const int output_channels = output.dims()[1];
+     const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+     const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+     const int output_channels = output.dims()[axis];
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
@@ -75,11 +86,18 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
-           int input_idx0 = (blen + clen) * groups + f;
+           int input_idx0, output_idx;
          bool continue_match = true;
-           int output_idx = blen + clen + f;
+           if (axis == 1) {
+             input_idx0 = (blen + clen) * groups + f;
+             output_idx = blen + clen + f;
+           } else {
+             input_idx0 = (blen + f * output_channels + c) * groups;
+             output_idx = blen + f * output_channels + c;
+           }
          for (int g = 0; g < groups && continue_match; ++g) {
-             int input_idx = input_idx0 + fea_size * g;
+             int idx_offset = (axis == 1 ? fea_size * g : g);
+             int input_idx = input_idx0 + idx_offset;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
...
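The index arithmetic above groups `groups` consecutive channels into one output channel, for either layout. An equivalent NumPy reference (an illustrative sketch; the function name is made up):

import numpy as np

def maxout_ref(x, groups, axis=1):
    # Split the channel dimension into (C // groups, groups) and take the max
    # over each group of `groups` consecutive channels, matching the functor.
    if axis == -1:
        axis = x.ndim - 1
    c = x.shape[axis]
    assert c % groups == 0
    new_shape = list(x.shape)
    new_shape[axis:axis + 1] = [c // groups, groups]
    return x.reshape(new_shape).max(axis=axis + 1)

nchw = np.random.rand(2, 6, 4, 5).astype("float32")
nhwc = np.transpose(nchw, (0, 2, 3, 1))
out_nchw = maxout_ref(nchw, groups=3, axis=1)    # (2, 2, 4, 5)
out_nhwc = maxout_ref(nhwc, groups=3, axis=-1)   # (2, 4, 5, 2)
assert np.allclose(np.transpose(out_nchw, (0, 2, 3, 1)), out_nhwc)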
@@ -22,8 +22,8 @@ namespace math {
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
-                              const int input_width, int groups,
-                              T* output_data) {
+                              const int input_width, const int groups,
+                              const int axis, T* output_data) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
-     int channel_idx = batch_offset / feat_len;
-     int feat_idx = batch_offset % feat_len;
-     int data_idx =
-         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+     int channel_idx, feat_idx, data_idx;
+     if (axis == 1) {
+       channel_idx = batch_offset / feat_len;
+       feat_idx = batch_offset % feat_len;
+       data_idx =
+           (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+     } else {
+       channel_idx = batch_offset % channels;
+       feat_idx = batch_offset / channels;
+       data_idx =
+           (batch_idx * size + feat_idx * channels + channel_idx) * groups;
+     }
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
-       T x = input_data[data_idx + g * feat_len];
+       int idx_offset = (axis == 1 ? g * feat_len : g);
+       T x = input_data[data_idx + idx_offset];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
@@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
                                 const T* output_data, const T* output_grad,
                                 T* input_grad, const int channels,
                                 const int input_height, const int input_width,
-                                  int groups) {
+                                  const int groups, const int axis) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
-     int channel_idx = batch_offset / feat_len;
-     int feat_idx = batch_offset % feat_len;
-     int data_idx =
-         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+     int channel_idx, feat_idx, data_idx;
+     if (axis == 1) {
+       channel_idx = batch_offset / feat_len;
+       feat_idx = batch_offset % feat_len;
+       data_idx =
+           (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+     } else {
+       channel_idx = batch_offset % channels;
+       feat_idx = batch_offset / channels;
+       data_idx =
+           (batch_idx * size + feat_idx * channels + channel_idx) * groups;
+     }
    int max_index = -1;
    bool continue_match = true;
    for (int g = 0; g < groups && continue_match; ++g) {
-       if (input_data[data_idx + g * feat_len] == output_data[i]) {
-         max_index = data_idx + g * feat_len;
+       int idx_offset = (axis == 1 ? g * feat_len : g);
+       if (input_data[data_idx + idx_offset] == output_data[i]) {
+         max_index = data_idx + idx_offset;
        continue_match = false;
        break;
      }
@@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
  }
}
/*
- * All tensors are in NCHW format.
+ * All tensors are in NCHW or NHWC format.
 */
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
-                   int groups) {
+                   const int groups, const int axis) {
    const int batch_size = input.dims()[0];
-     const int input_channels = input.dims()[1];
-     const int input_height = input.dims()[2];
-     const int input_width = input.dims()[3];
-     const int output_channels = output->dims()[1];
-     const int output_height = output->dims()[2];
-     const int output_width = output->dims()[3];
+     const int input_channels = input.dims()[axis];
+     const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+     const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+     const int output_channels = output->dims()[axis];
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
@@ -100,11 +116,11 @@ class MaxOutFunctor<platform::CUDADeviceContext, T> {
    KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width, groups,
-         output_data);
+         axis, output_data);
  }
};
/*
- * All tensors are in NCHW format.
+ * All tensors are in NCHW or NHWC format.
 */
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
@@ -112,14 +128,13 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
-                   const framework::Tensor& output_grad, int groups) {
+                   const framework::Tensor& output_grad, const int groups,
+                   const int axis) {
    const int batch_size = input.dims()[0];
-     const int input_channels = input.dims()[1];
-     const int input_height = input.dims()[2];
-     const int input_width = input.dims()[3];
-     const int output_channels = output.dims()[1];
-     const int output_height = output.dims()[2];
-     const int output_width = output.dims()[3];
+     const int input_channels = input.dims()[axis];
+     const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
+     const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
+     const int output_channels = output.dims()[axis];
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
@@ -132,7 +147,7 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
    KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
-         input_channels, input_height, input_width, groups);
+         input_channels, input_height, input_width, groups, axis);
  }
};
...
@@ -26,7 +26,8 @@ template <typename DeviceContext, typename T>
class MaxOutFunctor {
 public:
  void operator()(const DeviceContext& context, const framework::Tensor& input,
-                   framework::Tensor* output, int groups);
+                   framework::Tensor* output, const int groups,
+                   const int axis = 1);
};
template <typename DeviceContext, class T>
@@ -35,7 +36,8 @@ class MaxOutGradFunctor {
  void operator()(const DeviceContext& context, const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
-                   const framework::Tensor& output_grad, int groups);
+                   const framework::Tensor& output_grad, const int groups,
+                   const int axis = 1);
};
}  // namespace math
}  // namespace operators
...
@@ -23,25 +23,27 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
-     AddInput(
-         "X",
-         "(Tensor) The input tensor of maxout operator with data type of "
-         "float32. The format of input tensor is NCHW. Where N is batch size,"
-         " C is the number of channels, H and W is the height and width of "
-         "feature.");
+     AddInput("X",
+              "A 4-D Tensor with data type of float32 or float64. "
+              "The data format is NCHW or NHWC. Where N is "
+              "batch size, C is the number of channels, "
+              "H and W is the height and width of "
+              "feature. ");
    AddOutput("Out",
-               "(Tensor) The output tensor of maxout operator."
-               "The data type is float32."
-               "The format of output tensor is also NCHW."
-               "Where N is batch size, C is "
-               "the number of channels, H and W is the height and "
-               "width of feature.");
+               "A 4-D Tensor with same data type and data format "
+               "with input Tensor. ");
    AddAttr<int>(
        "groups",
-         "(int),"
-         "Specifies how many groups the input tensor will be split"
-         "in the channel dimension. And the number of output channel is "
-         "the number of channels divided by groups.");
+         "Specifies how many groups the input tensor will be split into "
+         "at the channel dimension. And the number of output channel is "
+         "the number of channels divided by groups. ");
+     AddAttr<int>(
+         "axis",
+         "Specifies the index of channel dimension where maxout will "
+         "be performed. It should be 1 when data format is NCHW, -1 or 3 "
+         "when data format is NHWC. "
+         "Default: 1. ")
+         .SetDefault(1);
    AddComment(R"DOC(
MaxOut Operator.
@@ -70,17 +72,19 @@ class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
-     PADDLE_ENFORCE(ctx->HasInput("X"),
-                    "Input(X) of MaxoutOpshould not be null.");
-     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                    "Output(Out) of MaxoutOp should not be null.");
+     PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                       "Input(X) of MaxoutOpshould not be null.");
+     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                       "Output(Out) of MaxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
+     int axis = ctx->Attrs().Get<int>("axis");
    // check groups > 1
-     PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
-     std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
-     output_shape.push_back(in_x_dims[2]);
-     output_shape.push_back(in_x_dims[3]);
+     PADDLE_ENFORCE_GT(groups, 1,
+                       "Attr(groups) of Op(maxout) should be larger than 1.");
+     std::vector<int64_t> output_shape(
+         {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
+     output_shape[axis] = in_x_dims[axis] / groups;
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};
...
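The new InferShape only divides the channel dimension selected by Attr(axis); a tiny mirror of that rule (illustrative only):

def maxout_infer_shape(in_dims, groups, axis):
    out = list(in_dims)
    out[axis] = in_dims[axis] // groups
    return out

assert maxout_infer_shape([8, 6, 32, 32], groups=3, axis=1) == [8, 2, 32, 32]   # NCHW
assert maxout_infer_shape([8, 32, 32, 6], groups=3, axis=3) == [8, 32, 32, 2]   # NHWC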
@@ -30,10 +30,11 @@ class MaxOutKernel : public framework::OpKernel<T> {
    const Tensor* in_x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");
    int groups = context.template Attr<int>("groups");
+     int axis = context.template Attr<int>("axis");
    math::MaxOutFunctor<DeviceContext, T> maxout_forward;
    maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
-                    groups);
+                    groups, axis);
  }
};
@@ -47,13 +48,15 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    int groups = context.template Attr<int>("groups");
+     int axis = context.template Attr<int>("axis");
    auto& device_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
      math::MaxOutGradFunctor<DeviceContext, T> maxout_backward;
-       maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups);
+       maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups,
+                       axis);
    }
  }
};
...
...@@ -9334,7 +9334,8 @@ def lod_append(x, level): ...@@ -9334,7 +9334,8 @@ def lod_append(x, level):
return out return out
def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
data_format='NCHW'):
""" """
This operator implements the Local Response Normalization Layer. This operator implements the Local Response Normalization Layer.
This layer performs a type of "lateral inhibition" by normalizing over local input regions. This layer performs a type of "lateral inhibition" by normalizing over local input regions.
...@@ -9355,13 +9356,18 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): ...@@ -9355,13 +9356,18 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
Args: Args:
input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError. input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W] or [N, H, W, C],
where N is the batch size, C is the input channel, H is Height, W is weight. The data
type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError.
n (int, optional): The number of channels to sum over. Default: 5 n (int, optional): The number of channels to sum over. Default: 5
k (float, optional): An offset, positive. Default: 1.0 k (float, optional): An offset, positive. Default: 1.0
alpha (float, optional): The scaling parameter, positive. Default:1e-4 alpha (float, optional): The scaling parameter, positive. Default:1e-4
beta (float, optional): The exponent, positive. Default:0.75 beta (float, optional): The exponent, positive. Default:0.75
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`
data_format(str, optional): The data format of the input and output data. An optional string
from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'.
Returns: Returns:
Variable: A tensor variable storing the transformation result with the same shape and data type as input. Variable: A tensor variable storing the transformation result with the same shape and data type as input.
...@@ -9384,8 +9390,12 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): ...@@ -9384,8 +9390,12 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
if dims != 4: if dims != 4:
raise ValueError( raise ValueError(
"dims of input must be 4(not %d), and it's order must be NCHW" % "Input's dimension size of Op(lrn) must be 4, but received %d." %
(dims)) (dims))
if data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Attr(data_format) of Op(lrn) got wrong value: received " +
data_format + " but only NCHW or NHWC supported.")
mid_out = helper.create_variable_for_type_inference( mid_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True) dtype=dtype, stop_gradient=True)
...@@ -9397,10 +9407,13 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): ...@@ -9397,10 +9407,13 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
"Out": lrn_out, "Out": lrn_out,
"MidOut": mid_out, "MidOut": mid_out,
}, },
attrs={"n": n, attrs={
"k": k, "n": n,
"alpha": alpha, "k": k,
"beta": beta}) "alpha": alpha,
"beta": beta,
"data_format": data_format
})
return lrn_out return lrn_out
@@ -11547,7 +11560,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
    * Case 1 (input is a 2-D Tensor):
        Input:
            X.shape = [3, 5]
            X.data = [[0, 1, 2, 0, 0],
                      [0, 3, 4, 0, 0],
                      [0, 0, 0, 0, 0]]
@@ -11555,8 +11568,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
            shape = [2, 2]
            offsets = [0, 1]
        Output:
            Out.shape = [2, 2]
            Out.data = [[1, 2],
                        [3, 4]]
    * Case 2 (input is a 3-D Tensor):
        Input:
            X.shape = [2, 3, 4]
@@ -11567,24 +11581,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
                      [0, 6, 7, 8],
                      [0, 0, 0, 0]]]
        Parameters:
            shape = [2, 2, -1]
            offsets = [0, 0, 1]
        Output:
            Out.shape = [2, 2, 3]
            Out.data = [[[1, 2, 3],
                         [5, 6, 7]],
                        [[3, 4, 5],
                         [6, 7, 8]]]
    Parameters:
        x (Variable): 1-D to 6-D Tensor, the data type is float32, float64, int32 or int64.
        shape (list|tuple|Variable): The output shape is specified
            by `shape`. Its data type is int32. If a list/tuple, its length must be
            the same as the dimension size of `x`. If a Variable, it should be a 1-D Tensor.
            When it is a list, each element can be an integer or a Tensor of shape: [1].
            If Variable contained, it is suitable for the case that the shape may
            be changed each iteration.
        offsets (list|tuple|Variable, optional): Specifies the cropping
            offsets at each dimension. Its data type is int32. If a list/tuple, its length
            must be the same as the dimension size of `x`. If a Variable, it should be a 1-D
@@ -11598,8 +11611,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
        Variable: The cropped Tensor has the same data type as `x`.
    Raises:
        TypeError: If the data type of `x` is not in: float32, float64, int32, int64.
        TypeError: If `shape` is not a list, tuple or Variable.
        TypeError: If the data type of `shape` is not int32.
        TypeError: If `offsets` is not None and not a list, tuple or Variable.
        TypeError: If the data type of `offsets` is not int32.
        ValueError: If the element in `offsets` is less than zero.

    Examples:
@@ -11615,7 +11632,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
            # crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime.

            # or shape is a list in which each element is a constant
            crop1 = fluid.layers.crop_tensor(x, shape=[-1, -1, 3], offsets=[0, 1, 0])
            # crop1.shape = [-1, 2, 3]

            # or shape is a list in which each element is a constant or Variable
@@ -11637,70 +11654,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
""" """
helper = LayerHelper('crop_tensor', **locals()) helper = LayerHelper('crop_tensor', **locals())
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
"Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not (isinstance(shape, list) or isinstance(shape, tuple) or \ if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)): isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.") raise TypeError(
"Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable."
)
if offsets is None: if offsets is None:
offsets = [0] * len(x.shape) offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \ if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)): isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.") raise TypeError(
"Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable."
)
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x} ipts = {'X': x}
attrs = {} attrs = {}
def contain_var(input_list): def _contain_var(input_list):
for ele in input_list: for ele in input_list:
if isinstance(ele, Variable): if isinstance(ele, Variable):
return True return True
return False return False
def _attr_shape_check(shape_val):
if not isinstance(shape_val, int):
raise TypeError(
"Attr(shape)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(shape_val))
if shape_val == 0:
raise ValueError(
"Attr(shape) of Op(crop_tensor) should not be zero, but received: %s."
% str(shape_val))
if shape_val < -1:
raise ValueError(
"When the element in Attr(shape) of Op(crop_tensor) is negative, only -1 is supported, but received: %s."
% str(shape_val))
def _attr_offsets_check(offset_val):
if not isinstance(offset_val, int):
raise TypeError(
"Attr(offsets)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(offset_val))
if offset_val < 0:
raise ValueError(
"Attr(offsets) of Op(crop_tensor) should be greater or equal to zero, but received: %s."
% str(offset_val))
if isinstance(offsets, Variable): if isinstance(offsets, Variable):
offsets.stop_gradient = True offsets.stop_gradient = True
ipts['Offsets'] = offsets ipts['Offsets'] = offsets
elif contain_var(offsets): attrs['offsets'] = [-1] * len(x.shape)
elif _contain_var(offsets):
new_offsets_tensor = [] new_offsets_tensor = []
offsets_attr = []
for dim in offsets: for dim in offsets:
if isinstance(dim, Variable): if isinstance(dim, Variable):
dim.stop_gradient = True dim.stop_gradient = True
new_offsets_tensor.append(dim) new_offsets_tensor.append(dim)
offsets_attr.append(-1)
else: else:
assert (isinstance(dim, int)) _attr_offsets_check(dim)
assert dim >= 0, ("offsets should be greater or equal to zero.")
temp_out = helper.create_variable_for_type_inference('int32') temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out) new_offsets_tensor.append(temp_out)
offsets_attr.append(dim)
ipts['OffsetsTensor'] = new_offsets_tensor ipts['OffsetsTensor'] = new_offsets_tensor
attrs['offsets'] = offsets_attr
else: else:
for offset in offsets:
_attr_offsets_check(offset)
attrs['offsets'] = offsets attrs['offsets'] = offsets
unk_dim_idx = -1
if isinstance(shape, Variable): if isinstance(shape, Variable):
shape.stop_gradient = True shape.stop_gradient = True
ipts['Shape'] = shape ipts['Shape'] = shape
elif contain_var(shape): elif _contain_var(shape):
new_shape_tensor = [] new_shape_tensor = []
shape_attr = [] shape_attr = []
for dim_idx, dim_size in enumerate(shape): for dim_size in shape:
if isinstance(dim_size, Variable): if isinstance(dim_size, Variable):
dim_size.stop_gradient = True dim_size.stop_gradient = True
new_shape_tensor.append(dim_size) new_shape_tensor.append(dim_size)
shape_attr.append(-1) shape_attr.append(0)
else: else:
assert (isinstance(dim_size, int)) _attr_shape_check(dim_size)
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
temp_out = helper.create_variable_for_type_inference('int32') temp_out = helper.create_variable_for_type_inference('int32')
fill_constant( fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out) [1], 'int32', dim_size, force_cpu=True, out=temp_out)
...@@ -11709,6 +11754,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -11709,6 +11754,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
ipts['ShapeTensor'] = new_shape_tensor ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr attrs['shape'] = shape_attr
else: else:
for dim_size in shape:
_attr_shape_check(dim_size)
attrs['shape'] = shape attrs['shape'] = shape
helper.append_op( helper.append_op(
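The remainder of the append_op call is collapsed in the diff above. As a quick sanity check of the relaxed Attr(shape) semantics introduced here (every element may now be -1, and each -1 dimension is inferred as the input size minus the offset on that axis), the following is a minimal sketch rather than part of the patch; it assumes the fluid static-graph API used in the docstring examples, and the tensor name 'x' is arbitrary:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.data(name='x', shape=[3, 5], dtype='float32')
    # dim 0 is -1, so it is inferred as x.shape[0] - offsets[0] = 3 - 1 = 2.
    out = fluid.layers.crop_tensor(x, shape=[-1, 2], offsets=[1, 0])

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    x_np = np.arange(15).reshape(3, 5).astype('float32')
    res, = exe.run(fluid.default_main_program(), feed={'x': x_np}, fetch_list=[out])
    # res.shape == (2, 2): rows 1..2 and columns 0..1 of x_np.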
@@ -15195,22 +15242,23 @@ def sigmoid_cross_entropy_with_logits(x,
@templatedoc()
def maxout(x, groups, name=None, axis=1):
    """
    ${comment}

    Args:
        x(${x_type}): ${x_comment}
        groups(int): ${groups_comment}
        axis(int, optional): ${axis_comment}
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually it is not necessary to set
            this property, and it is None by default.

    Returns:
        Variable: ${out_comment}

    Raises:
        ValueError: If `axis` is not 1, -1 or 3.

    Examples:
        .. code-block:: python
@@ -15223,6 +15271,12 @@ def maxout(x, groups, name=None):
            out = fluid.layers.maxout(input, groups=2)
    """
    helper = LayerHelper("maxout", **locals())
    if axis not in [1, -1, 3]:
        raise ValueError(
            "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
            "Attr(axis): %s." % str(axis))
    if axis == -1:
        axis = 3

    if name is None:
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -15233,7 +15287,8 @@ def maxout(x, groups, name=None):
    helper.append_op(
        type="maxout",
        inputs={"X": x},
        attrs={"groups": groups,
               "axis": axis},
        outputs={"Out": out})
    return out
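A minimal sketch of the new channel_last usage (illustration only; it assumes an NHWC input and the made-up name `x_nhwc`):

    import numpy as np
    import paddle.fluid as fluid

    # With axis=-1 (equivalent to axis=3) the last dimension is grouped.
    x_nhwc = fluid.data(name='x_nhwc', shape=[2, 4, 4, 6], dtype='float32')
    out = fluid.layers.maxout(x_nhwc, groups=2, axis=-1)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    feed = {'x_nhwc': np.random.random((2, 4, 4, 6)).astype('float32')}
    res, = exe.run(fluid.default_main_program(), feed=feed, fetch_list=[out])
    # res.shape == (2, 4, 4, 3): the 6 channels are reduced to 6 // groups = 3.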
......
@@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape):
class TestCropTensorOp(OpTest):
    def setUp(self):
        self.op_type = "crop_tensor"
        self.shape_by_input = False
        self.offset_by_input = False
        self.unk_dim_idx = -1
        self.attrs = {}
        self.initTestCase()

        if self.shape_by_input:
            self.inputs = {
                'X': np.random.random(self.x_shape).astype("float32"),
                'Shape': np.array(self.crop_shape).astype("int32")
@@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest):
        else:
            self.attrs['offsets'] = self.offsets

        crop_shape = [val for val in self.crop_shape]
        for i in range(len(self.crop_shape)):
            if self.crop_shape[i] == -1:
                crop_shape[i] = self.x_shape[i] - self.offsets[i]
        self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}

    def initTestCase(self):
        self.x_shape = (8, 8)
@@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp):
class TestCase2(TestCropTensorOp):
    def initTestCase(self):
        self.x_shape = (12, 24)
        self.crop_shape = [-1, 8]
        self.offsets = [0, 0]

class TestCase3(TestCropTensorOp):
@@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp):
        self.x_shape = (4, 8, 16)
        self.crop_shape = [2, 2, 3]
        self.offsets = [1, 5, 3]
        self.shape_by_input = True

class TestCase4(TestCropTensorOp):
    def initTestCase(self):
        self.x_shape = (8, 3, 6, 6)
        self.crop_shape = [-1, 3, -1, 4]
        self.offsets = [0, 0, 1, 0]
        self.shape_by_input = True

class TestCase5(TestCropTensorOp):
@@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp):
        self.x_shape = (2, 2, 4, 4, 4, 2)
        self.crop_shape = [1, 1, 4, 2, 2, 2]
        self.offsets = [0, 0, 0, 0, 0, 0]
        self.shape_by_input = True
        self.offset_by_input = True

class TestCropTensorOpTensorAttr(OpTest):
    def setUp(self):
        self.op_type = "crop_tensor"
        self.OffsetsTensor = False
        self.ShapeTensor = True
        self.attrs = {}
@@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest):
                'X': np.random.random(self.x_shape).astype("float32"),
                'ShapeTensor': shape_tensor
            }
            self.attrs['shape'] = self.shape_attr

        if self.OffsetsTensor:
            offsets_tensor = []
@@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest):
                'X': np.random.random(self.x_shape).astype("float32"),
                'OffsetsTensor': offsets_tensor
            }
            self.attrs['offsets'] = self.offsets_attr

        self.attrs['shape'] = self.crop_shape
        self.attrs['offsets'] = self.offsets
        crop_shape = [val for val in self.crop_shape]
        for i in range(len(self.crop_shape)):
            if self.crop_shape[i] == -1:
                crop_shape[i] = self.x_shape[i] - self.offsets[i]
        self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}

    def initTestCase(self):
        self.x_shape = (8, 8)
        self.crop_shape = (2, 2)
        self.offsets = [1, 2]
        self.shape_attr = [0, 0]

    def test_check_output(self):
        self.check_output()
@@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest):
        self.check_grad(["X"], "Out", max_relative_error=0.006)

class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
    def initTestCase(self):
        self.x_shape = (16, 8, 32)
        self.crop_shape = [-1, -1, 3]
        self.offsets = [1, 5, 3]
        self.shape_attr = [-1, -1, 3]

class TestCropTensorOpTensorAttrCase2(TestCropTensorOpTensorAttr):
    def initTestCase(self):
        self.x_shape = (4, 8, 16, 8)
        self.crop_shape = [2, 2, 3, 4]
        self.offsets = [1, 5, 3, 0]
        self.shape_attr = [0, 0, 3, 4]

class TestCropTensorOpTensorAttrCase3(TestCropTensorOpTensorAttr):
    def initTestCase(self):
        self.x_shape = (16, 8, 32)
        self.crop_shape = [2, 2, 3]
        self.offsets = [1, 5, 3]
        self.offsets_attr = [-1, -1, 3]
        self.ShapeTensor = False
        self.OffsetsTensor = True

class TestCropTensorOpTensorAttrCase4(TestCropTensorOpTensorAttr):
    def initTestCase(self):
        self.x_shape = (16, 8, 32)
        self.crop_shape = [2, 2, 3]
        self.shape_attr = [0, 2, 3]
        self.offsets = [1, 5, 3]
        self.offsets_attr = [-1, -1, 3]
        self.OffsetsTensor = True

class TestCropTensorException(OpTest):
    def test_exception(self):
        input1 = fluid.data(name="input1", shape=[2, 3, 6, 6], dtype="float32")
        input2 = fluid.data(name="input2", shape=[2, 3, 6, 6], dtype="float16")
        dim = fluid.data(name='dim', shape=[1], dtype='int32')
        offset = fluid.data(name='offset', shape=[1], dtype='int32')

        def attr_shape_type():
            out = fluid.layers.crop_tensor(input1, shape=3)

        def attr_shape_dtype():
            out = fluid.layers.crop_tensor(input1, shape=[2, 2.0, 3, 3])

        def attr_shape_value1():
            out = fluid.layers.crop_tensor(input1, shape=[2, -2, dim, 3])

        def attr_shape_value2():
            out = fluid.layers.crop_tensor(input1, shape=[2, 0, dim, 3])

        def attr_offsets_type():
            out = fluid.layers.crop_tensor(
                input1, shape=[2, 2, 3, 3], offsets=0)

        def attr_offsets_dtype():
            out = fluid.layers.crop_tensor(
                input1, shape=[2, 2, 3, 3], offsets=[0, 1.0, 0, 0])

        def attr_offsets_value():
            out = fluid.layers.crop_tensor(
                input1, shape=[2, 2, 3, 3], offsets=[0, -1, offset, 0])

        def input_dtype():
            out = fluid.layers.crop_tensor(input2, shape=[2, 2, 3, 3])

        self.assertRaises(TypeError, attr_shape_type)
        self.assertRaises(TypeError, attr_shape_dtype)
        self.assertRaises(ValueError, attr_shape_value1)
        self.assertRaises(ValueError, attr_shape_value2)
        self.assertRaises(TypeError, attr_offsets_type)
        self.assertRaises(TypeError, attr_offsets_dtype)
        self.assertRaises(ValueError, attr_offsets_value)
        self.assertRaises(TypeError, input_dtype)

if __name__ == '__main__':
    unittest.main()
@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
@@ -60,12 +62,15 @@ class TestLRNOp(OpTest):
            'n': self.n,
            'k': self.k,
            'alpha': self.alpha,
            'beta': self.beta,
            'data_format': self.data_format
        }
        return attrs

    def setUp(self):
        self.op_type = "lrn"
        self.init_test_case()
        self.N = 2
        self.C = 3
        self.H = 5
@@ -77,11 +82,18 @@ class TestLRNOp(OpTest):
        self.beta = 0.75
        self.x = self.get_input()
        self.out, self.mid_out = self.get_out()
        if self.data_format == 'NHWC':
            self.x = np.transpose(self.x, [0, 2, 3, 1])
            self.out = np.transpose(self.out, [0, 2, 3, 1])
            self.mid_out = np.transpose(self.mid_out, [0, 2, 3, 1])

        self.inputs = {'X': self.x}
        self.outputs = {'Out': self.out, 'MidOut': self.mid_out}
        self.attrs = self.get_attrs()

    def init_test_case(self):
        self.data_format = 'NCHW'

    def test_check_output(self):
        self.check_output()
@@ -89,5 +101,49 @@ class TestLRNOp(OpTest):
        self.check_grad(['X'], 'Out', max_relative_error=0.01)

class TestLRNOpAttrDataFormat(TestLRNOp):
    def init_test_case(self):
        self.data_format = 'NHWC'

class TestLRNAPI(OpTest):
    def test_case(self):
        data1 = fluid.data(name='data1', shape=[2, 4, 5, 5], dtype='float32')
        data2 = fluid.data(name='data2', shape=[2, 5, 5, 4], dtype='float32')
        out1 = fluid.layers.lrn(data1, data_format='NCHW')
        out2 = fluid.layers.lrn(data2, data_format='NHWC')
        data1_np = np.random.random((2, 4, 5, 5)).astype("float32")
        data2_np = np.transpose(data1_np, [0, 2, 3, 1])

        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
        else:
            place = core.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        results = exe.run(fluid.default_main_program(),
                          feed={"data1": data1_np,
                                "data2": data2_np},
                          fetch_list=[out1, out2],
                          return_numpy=True)

        self.assertTrue(
            np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))

    def test_exception(self):
        input1 = fluid.data(name="input1", shape=[2, 4, 5, 5], dtype="float32")
        input2 = fluid.data(
            name="input2", shape=[2, 4, 5, 5, 5], dtype="float32")

        def _attr_data_format():
            out = fluid.layers.lrn(input1, data_format='NDHW')

        def _input_dim_size():
            out = fluid.layers.lrn(input2)

        self.assertRaises(ValueError, _attr_data_format)
        self.assertRaises(ValueError, _input_dim_size)

if __name__ == "__main__":
    unittest.main()
@@ -16,11 +16,16 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest

def maxout_forward_naive(input, groups, channel_axis):
    s0, s1, s2, s3 = input.shape
    if channel_axis == 3:
        return np.ndarray([s0, s1, s2, s3 // groups, groups], \
            buffer=input, dtype=input.dtype).max(axis=(4))
    return np.ndarray([s0, s1 // groups, groups, s2, s3], \
        buffer=input, dtype=input.dtype).max(axis=(2))
@@ -30,10 +35,11 @@ class TestMaxOutOp(OpTest):
        self.op_type = "maxout"
        self.init_test_case()
        input = np.random.random(self.shape).astype("float32")
        output = self.MaxOut_forward_naive(input, self.groups,
                                           self.axis).astype("float32")

        self.inputs = {'X': input}
        self.attrs = {'groups': self.groups, 'axis': self.axis}
        self.outputs = {'Out': output.astype('float32')}
@@ -47,6 +53,48 @@ class TestMaxOutOp(OpTest):
        self.MaxOut_forward_naive = maxout_forward_naive
        self.shape = [100, 6, 2, 2]
        self.groups = 2
        self.axis = 1

class TestMaxOutOpAxis(TestMaxOutOp):
    def init_test_case(self):
        self.MaxOut_forward_naive = maxout_forward_naive
        self.shape = [100, 2, 2, 6]  # NHWC format
        self.groups = 2
        self.axis = 3

class TestMaxOutOpAxisAPI(OpTest):
    def test_axis(self):
        data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
        data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
        out1 = fluid.layers.maxout(data1, groups=2, axis=1)
        out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
        data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
        data2_np = np.transpose(data1_np, [0, 2, 3, 1])

        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
        else:
            place = core.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        results = exe.run(fluid.default_main_program(),
                          feed={"data1": data1_np,
                                "data2": data2_np},
                          fetch_list=[out1, out2],
                          return_numpy=True)

        self.assertTrue(
            np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))

    def test_exception(self):
        input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")

        def _attr_axis():
            out = fluid.layers.maxout(input, groups=2, axis=2)

        self.assertRaises(ValueError, _attr_axis)

if __name__ == '__main__': if __name__ == '__main__':
......