Unverified commit 3848f720, authored by Zhang Ting, committed by GitHub

[cherry-pick] fix crop_tensor, maxout and lrn (#21302)

* [cherry-pick] All elements in attr(shape) of crop_tensor can be -1 and int32/64 kernel registered (#20756)

* All elements in attr(shape) of crop_tensor can be -1, test=develop, test=document_preview

* fix the bug that attr(offsets) should be initialized, test=develop

* [cherry-pick] maxout supports channel_last input (#20846)

* maxout supports channel_last input, test=develop

* modified details of Input(X) and Attr(groups, axis) in doc, test=develop

* [cherry-pick] lrn supports channel_last input, test=develop (#20954)
Parent 9f004548
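A minimal usage sketch of the three changes at the Python API layer (tensor names and shapes below are illustrative, not taken from the patch):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[2, 3, 6, 6], dtype='float32')

# crop_tensor: every element of Attr(shape) may now be -1, meaning
# "keep x.shape[i] - offsets[i] elements along dimension i".
crop = fluid.layers.crop_tensor(x, shape=[-1, -1, 4, 4], offsets=[0, 0, 1, 1])

# maxout: channel_last input is supported through the new axis attribute.
x_nhwc = fluid.data(name='x_nhwc', shape=[2, 6, 6, 4], dtype='float32')
pooled = fluid.layers.maxout(x_nhwc, groups=2, axis=-1)

# lrn: channel_last input is supported through the new data_format attribute.
normed = fluid.layers.lrn(x_nhwc, data_format='NHWC')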
......@@ -31,8 +31,9 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Input(X) of Op(crop_tensor) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of Op(crop_tensor) should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto offsets = ctx->Attrs().Get<std::vector<int>>("offsets");
if (ctx->HasInputs("ShapeTensor")) {
// top priority shape
auto inputs_name = ctx->Inputs("ShapeTensor");
......@@ -43,15 +44,19 @@ class CropTensorOp : public framework::OperatorWithKernel {
"Op(fluid.layers.crop_tensor).");
auto out_dims = std::vector<int>(inputs_name.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != -1) {
if (shape[i] > 0) {
out_dims[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_dims[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
auto x_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Shape")) {
auto shape_dim = ctx->GetInputDim("Shape");
PADDLE_ENFORCE_EQ(
......@@ -78,11 +83,17 @@ class CropTensorOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(int64_t(shape.size()), x_dim.size(),
"Attr(shape)'size of Op(crop_tensor) should be equal to "
"dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
std::vector<int64_t> out_shape(shape.size(), -1);
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
if (shape[i] > 0) {
out_shape[i] = static_cast<int64_t>(shape[i]);
} else {
if (shape[i] == -1 && offsets[i] != -1 && x_dim[i] != -1) {
out_shape[i] = x_dim[i] - static_cast<int64_t>(offsets[i]);
}
}
}
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
framework::OpKernelType GetExpectedKernelType(
......@@ -294,8 +305,12 @@ REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
REGISTER_OP_CPU_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>);
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
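The InferShape change above can be summarized as: a positive entry of Attr(shape) is taken literally, while a -1 entry resolves to x_dim[i] - offsets[i] once both are known. A rough Python sketch of that rule (helper name is hypothetical):

def infer_crop_shape(x_dim, shape, offsets):
    # Mirrors CropTensorOp::InferShape: a positive entry is taken as-is,
    # a -1 entry becomes x_dim[i] - offsets[i] when both are known.
    out = [-1] * len(shape)
    for i, s in enumerate(shape):
        if s > 0:
            out[i] = s
        elif s == -1 and offsets[i] != -1 and x_dim[i] != -1:
            out[i] = x_dim[i] - offsets[i]
    return out

# e.g. infer_crop_shape([8, 3, 6, 6], [-1, 3, -1, 4], [0, 0, 1, 0]) -> [8, 3, 5, 4]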
......@@ -17,8 +17,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>);
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -50,29 +50,28 @@ inline std::vector<int> get_new_data(
}
static framework::DDim ValidateShape(const std::vector<int> shape,
const std::vector<int> offsets,
const framework::DDim& in_dims) {
auto in_dim_size = in_dims.size();
auto shape_size = shape.size();
PADDLE_ENFORCE_EQ(
in_dim_size, shape_size,
"Input(ShapeTensor)'s dimension size of Op(crop_tensor) should be equal "
"to that of input tensor. "
"Attr(shape)'s size of Op(crop_tensor) should be equal "
"to that of input Tensor. "
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
const int64_t unk_dim_val = -1;
int unk_dim_idx = -1;
std::vector<int64_t> output_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ(unk_dim_idx, -1,
"Only one element of shape can be unknown.");
PADDLE_ENFORCE_EQ(i, 0, "Only the first element of shape can be -1.");
unk_dim_idx = i;
if (shape[i] <= 0 && in_dims[i] > 0) {
PADDLE_ENFORCE_NE(
shape[i], 0,
"The element in Attr(shape) of Op(crop_tensor) should not be zero.");
PADDLE_ENFORCE_EQ(shape[i], -1,
"When the element in Attr(shape) of Op(crop_tensor) is "
"negative, only -1 is supported.");
output_shape[i] = in_dims[i] - offsets[i];
} else {
PADDLE_ENFORCE_GT(shape[i], 0,
"Each element of shape must be greater than 0 "
"except the first element.");
output_shape[i] = static_cast<int64_t>(shape[i]);
}
output_shape[i] = static_cast<int64_t>(shape[i]);
}
return framework::make_ddim(output_shape);
......@@ -164,21 +163,15 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
shape.push_back(out_dims[i]);
}
}
out_dims = ValidateShape(shape, x->dims());
if (out_dims[0] == -1) {
out_dims[0] = x->dims()[0];
}
out->mutable_data<T>(out_dims, context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto offsets = GetOffsets(context);
int64_t offset = 0;
out_dims = ValidateShape(shape, offsets, x->dims());
out->mutable_data<T>(out_dims, context.GetPlace());
for (size_t i = 0; i < offsets.size(); ++i) {
PADDLE_ENFORCE_LE(
offsets[i] + shape[i], x_dims[i],
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
"should be less than or equal to corresponding input dimension size.");
offset += (x_stride[i] * offsets[i]);
}
auto x_tensor = EigenTensor<T, D>::From(*x);
......
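At execution time the kernel validates shape and offsets against the input and copies the selected region; a NumPy sketch of the equivalent slicing, assuming the same -1 semantics (function name is hypothetical):

import numpy as np

def crop_numpy(x, shape, offsets):
    # Equivalent of CropTensorFunction for a plain ndarray:
    # each -1 in shape keeps x.shape[i] - offsets[i] elements.
    slices = []
    for i, s in enumerate(shape):
        size = x.shape[i] - offsets[i] if s == -1 else s
        slices.append(slice(offsets[i], offsets[i] + size))
    return x[tuple(slices)]

x = np.arange(24).reshape(2, 3, 4)
print(crop_numpy(x, [2, 2, -1], [0, 0, 1]).shape)  # (2, 2, 3)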
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/lrn_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -23,18 +25,41 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
struct LRNFunctor<platform::CPUDeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
const T* idata = input.data<T>();
T k, T alpha, T beta, const DataLayout data_layout) {
auto place = ctx.GetPlace();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
T* odata = out->mutable_data<T>(place);
T* mdata = mid->mutable_data<T>(place);
math::Transpose<platform::CPUDeviceContext, T, 4> transpose;
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
Tensor in_transpose, mid_transpose, out_transpose;
// if channel_last, transpose to channel_first
if (data_layout == DataLayout::kNHWC) {
auto in_dims = input.dims();
std::vector<int64_t> shape(
{in_dims[0], in_dims[3], in_dims[1], in_dims[2]});
in_transpose.mutable_data<T>(framework::make_ddim(shape), place);
mid_transpose.mutable_data<T>(framework::make_ddim(shape), place);
out_transpose.mutable_data<T>(framework::make_ddim(shape), place);
std::vector<int> axis = {0, 3, 1, 2};
transpose(dev_ctx, input, &in_transpose, axis);
} else {
in_transpose = input;
mid_transpose = *mid;
out_transpose = *out;
mid_transpose.mutable_data<T>(mid->dims(), place);
out_transpose.mutable_data<T>(out->dims(), place);
}
const T* idata = in_transpose.data<T>();
T* odata = out_transpose.data<T>();
T* mdata = mid_transpose.data<T>();
Tensor squared;
T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
std::memset(sdata, 0, sizeof(T) * squared.numel());
......@@ -67,6 +92,13 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
// compute the final output
blas.VPOW(mid->numel(), mdata, -beta, odata);
blas.VMUL(mid->numel(), odata, idata, odata);
// if channel_last, transpose the output(NCHW) to channel_last
if (data_layout == DataLayout::kNHWC) {
std::vector<int> axis = {0, 2, 3, 1};
transpose(dev_ctx, mid_transpose, mid, axis);
transpose(dev_ctx, out_transpose, out, axis);
}
}
};
template struct LRNFunctor<platform::CPUDeviceContext, float>;
......@@ -78,7 +110,7 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta) {
int n, T alpha, T beta, const DataLayout data_layout) {
T ratio = -2 * alpha * beta;
auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
x_g_e = x_g_e.constant(0.0);
......@@ -93,17 +125,17 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
const int end = start + n;
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto offsets = Eigen::array<int, 4>({{m, i, 0, 0}});
auto extents = Eigen::array<int, 4>({{1, 1, H, W}});
if (data_layout == DataLayout::kNHWC) {
offsets = Eigen::array<int, 4>({{m, 0, 0, i}});
extents = Eigen::array<int, 4>({{1, H, W, 1}});
}
auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_x = e_x.slice(offsets, extents);
auto i_x_g = e_x_g.slice(offsets, extents);
auto i_out_g = e_out_g.slice(offsets, extents);
auto i_mid = e_mid.slice(offsets, extents);
i_x_g = i_mid.pow(-beta) * i_out_g;
for (int c = start; c < end; c++) {
......@@ -112,14 +144,14 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
continue;
}
auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
if (data_layout != DataLayout::kNHWC) {
offsets = Eigen::array<int, 4>({{m, ch, 0, 0}});
} else {
offsets = Eigen::array<int, 4>({{m, 0, 0, ch}});
}
auto c_out = e_out.slice(offsets, extents);
auto c_mid = e_mid.slice(offsets, extents);
auto c_out_g = e_out_g.slice(offsets, extents);
i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
}
......@@ -156,9 +188,8 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
......@@ -242,8 +273,8 @@ $$
Function implementation:
Inputs and outputs are in NCHW format, while input.shape.ndims() equals 4.
And dimensions 0 ~ 3 represent batch size, feature maps, rows,
Inputs and outputs are in NCHW or NHWC format, while input.shape.ndims() equals 4.
If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows,
and columns, respectively.
Input and Output in the formula above is for each map(i) of one image, and
......@@ -275,9 +306,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
......
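The CPU functor above supports channel_last input by transposing NHWC data to NCHW, running the existing channel-first path, and transposing Out and MidOut back. A NumPy sketch of that strategy (lrn_nchw stands in for the NCHW implementation and is assumed, not defined here):

import numpy as np

def lrn_nhwc_via_nchw(x_nhwc, lrn_nchw):
    # Transpose NHWC -> NCHW, run the channel-first implementation,
    # then transpose both results back to NHWC.
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))
    out, mid = lrn_nchw(x_nchw)
    return np.transpose(out, (0, 2, 3, 1)), np.transpose(mid, (0, 2, 3, 1))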
......@@ -17,15 +17,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
using DataLayout = framework::DataLayout;
template <typename T>
__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
int H, int W, int size, T k, T alpha) {
int H, int W, int size, T k, T alpha,
const DataLayout data_layout) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < img_size) {
const int w = idx % W;
const int h = (idx / W) % H;
const int n = idx / W / H;
const int offset = (n * C * H + h) * W + w;
const int offset =
(data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
: ((n * H + h) * W + w) * C);
in += offset;
mid += offset;
......@@ -37,15 +42,21 @@ __global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
int index = 0;
while (index < C + post_pad) {
if (index < C) {
T val = in[index * step];
int in_idx = (data_layout != DataLayout::kNHWC ? index * step : index);
T val = in[in_idx];
accum += val * val;
}
if (index >= size) {
T val = in[(index - size) * step];
int in_idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
: index - size);
T val = in[in_idx];
accum -= val * val;
}
if (index >= post_pad) {
mid[(index - post_pad) * step] = k + accum * alpha;
int mid_idx =
(data_layout != DataLayout::kNHWC ? (index - post_pad) * step
: index - post_pad);
mid[mid_idx] = k + accum * alpha;
}
++index;
}
......@@ -64,14 +75,14 @@ __global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
template <typename T>
void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
T alpha, T beta) {
T alpha, T beta, const DataLayout data_layout) {
int img_size = N * H * W;
const int block_size = 1024;
int grid_size = (img_size + block_size - 1) / block_size;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
KeCMRNormFillScale<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
img_size, inputs, mid, C, H, W, n, k, alpha);
img_size, inputs, mid, C, H, W, n, k, alpha, data_layout);
int input_size = N * H * W * C;
grid_size = (input_size + block_size - 1) / block_size;
......@@ -84,10 +95,11 @@ struct LRNFunctor<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
CrossMapNormal<T>(
ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()),
mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta);
T k, T alpha, T beta, const DataLayout data_layout) {
CrossMapNormal<T>(ctx, input.data<T>(),
out->mutable_data<T>(ctx.GetPlace()),
mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k,
alpha, beta, data_layout);
}
};
......@@ -97,14 +109,16 @@ template struct LRNFunctor<platform::CUDADeviceContext, double>;
template <typename T>
__global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
const T* mid, T* x_g, const T* out_g, int C,
int H, int W, int size, T negative_beta,
T ratio) {
int H, int W, int size, T negative_beta, T ratio,
const DataLayout data_layout) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < img_size) {
const int w = idx % W;
const int h = (idx / W) % H;
const int n = idx / W / H;
const int offset = (n * C * H + h) * W + w;
const int offset =
(data_layout != DataLayout::kNHWC ? (n * C * H + h) * W + w
: ((n * H + h) * W + w) * C);
x += offset;
out += offset;
mid += offset;
......@@ -120,18 +134,20 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
// TODO(gongwb): optimize this with thread shared array.
while (index < C + post_pad) {
if (index < C) {
x_g[index * step] = 0.0;
accum += out_g[index * step] * out[index * step] / mid[index * step];
int idx = (data_layout != DataLayout::kNHWC ? index * step : index);
x_g[idx] = 0.0;
accum += out_g[idx] * out[idx] / mid[idx];
}
if (index >= size) {
accum -= out_g[(index - size) * step] * out[(index - size) * step] /
mid[(index - size) * step];
int idx = (data_layout != DataLayout::kNHWC ? (index - size) * step
: index - size);
accum -= out_g[idx] * out[idx] / mid[idx];
}
if (index >= post_pad) {
x_g[(index - post_pad) * step] +=
out_g[(index - post_pad) * step] *
pow(mid[(index - post_pad) * step], negative_beta) -
ratio * x[(index - post_pad) * step] * accum;
int idx = (data_layout != DataLayout::kNHWC ? (index - post_pad) * step
: index - post_pad);
x_g[idx] +=
out_g[idx] * pow(mid[idx], negative_beta) - ratio * x[idx] * accum;
}
++index;
}
......@@ -141,7 +157,8 @@ __global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
template <typename T>
void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
const T* out, const T* mid, T* x_g, const T* out_g,
int N, int C, int H, int W, int n, T alpha, T beta) {
int N, int C, int H, int W, int n, T alpha, T beta,
const DataLayout data_layout) {
int img_size = N * H * W;
const int block_size = 1024;
......@@ -149,8 +166,8 @@ void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
KeCMRNormDiff<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta,
2.0f * alpha * beta);
img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta, 2.0f * alpha * beta,
data_layout);
}
template <typename T>
......@@ -159,10 +176,10 @@ struct LRNGradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta) {
int n, T alpha, T beta, const DataLayout data_layout) {
CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(),
x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(),
N, C, H, W, n, alpha, beta);
N, C, H, W, n, alpha, beta, data_layout);
}
};
......
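The CUDA kernels avoid transposes and only change the indexing: the per-channel step is H*W with base offset (n*C*H + h)*W + w for NCHW, versus a step of 1 with base offset ((n*H + h)*W + w)*C for NHWC. A quick NumPy check of the two formulas (illustrative only):

import numpy as np

def flat_index(n, c, h, w, C, H, W, nhwc):
    # NCHW: base (n*C*H + h)*W + w, channel step H*W.
    # NHWC: base ((n*H + h)*W + w)*C, channel step 1.
    if nhwc:
        return ((n * H + h) * W + w) * C + c
    return ((n * C + c) * H + h) * W + w

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)            # NCHW
assert x[1, 2, 3, 4] == x.ravel()[flat_index(1, 2, 3, 4, 3, 4, 5, False)]
x_nhwc = np.transpose(x, (0, 2, 3, 1))                      # NHWC
assert x[1, 2, 3, 4] == x_nhwc.ravel()[flat_index(1, 2, 3, 4, 3, 4, 5, True)]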
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -21,12 +23,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
using DataLayout = framework::DataLayout;
template <typename place, typename T>
struct LRNFunctor {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta);
T k, T alpha, T beta,
const DataLayout data_layout = DataLayout::kAnyLayout);
};
template <typename DeviceContext, typename T>
......@@ -42,11 +47,14 @@ class LRNKernel : public framework::OpKernel<T> {
const Tensor& x = *ctx.Input<Tensor>("X");
auto x_dims = x.dims();
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
// NCHW
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
int W = x_dims[3];
int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
......@@ -65,7 +73,7 @@ class LRNKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
LRNFunctor<DeviceContext, T> f;
f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta);
f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta, data_layout);
}
};
......@@ -75,7 +83,8 @@ struct LRNGradFunctor {
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta);
int n, T alpha, T beta,
const DataLayout data_layout = DataLayout::kAnyLayout);
};
/**
......@@ -106,15 +115,18 @@ class LRNGradKernel : public framework::OpKernel<T> {
const Tensor& out = *ctx.Input<Tensor>("Out");
const Tensor& out_g = *ctx.Input<Tensor>(framework::GradVarName("Out"));
const Tensor& mid = *ctx.Input<Tensor>("MidOut");
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const framework::DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
x_g->mutable_data<T>(ctx.GetPlace());
auto x_dims = x.dims();
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
int W = x_dims[3];
int C = (data_layout != DataLayout::kNHWC ? x_dims[1] : x_dims[3]);
int H = (data_layout != DataLayout::kNHWC ? x_dims[2] : x_dims[1]);
int W = (data_layout != DataLayout::kNHWC ? x_dims[3] : x_dims[2]);
int n = ctx.Attr<int>("n");
T alpha = ctx.Attr<T>("alpha");
......@@ -125,7 +137,7 @@ class LRNGradKernel : public framework::OpKernel<T> {
"is_test attribute should be set to False in training phase.");
LRNGradFunctor<DeviceContext, T> f;
f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta, data_layout);
}
};
......
......@@ -18,35 +18,45 @@ namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW format, and the groups must be greater than 1
// All tensors are in NCHW or NHWC format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
int input_idx, output_idx;
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups +
ph * fea_size + f];
if (axis == 1) {
input_idx =
(new_bindex + new_cindex) * groups + ph * fea_size + f;
} else {
input_idx = (new_bindex + f * output_channels + c) * groups + ph;
}
T x = input_data[input_idx];
ele = ele > x ? ele : x;
}
output_data[(new_bindex + new_cindex + f)] = ele;
if (axis == 1) {
output_idx = new_bindex + new_cindex + f;
} else {
output_idx = new_bindex + f * output_channels + c;
}
output_data[output_idx] = ele;
}
}
}
......@@ -59,11 +69,12 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups) {
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
......@@ -75,11 +86,18 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0 = (blen + clen) * groups + f;
int input_idx0, output_idx;
bool continue_match = true;
int output_idx = blen + clen + f;
if (axis == 1) {
input_idx0 = (blen + clen) * groups + f;
output_idx = blen + clen + f;
} else {
input_idx0 = (blen + f * output_channels + c) * groups;
output_idx = blen + f * output_channels + c;
}
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
int idx_offset = (axis == 1 ? fea_size * g : g);
int input_idx = input_idx0 + idx_offset;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
......
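For the forward pass above, output channel c takes the maximum over input channels c*groups .. c*groups + groups - 1 at every spatial position, in either layout. A NumPy sketch of the same computation (function name is hypothetical):

import numpy as np

def maxout_naive(x, groups, axis):
    # axis == 1: NCHW input, channels split as C_in = C_out * groups.
    # axis == 3: NHWC input, same split applied to the last dimension.
    n, d1, d2, d3 = x.shape
    if axis == 1:
        return x.reshape(n, d1 // groups, groups, d2, d3).max(axis=2)
    return x.reshape(n, d1, d2, d3 // groups, groups).max(axis=4)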
......@@ -22,8 +22,8 @@ namespace math {
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels, const int input_height,
const int input_width, int groups,
T* output_data) {
const int input_width, const int groups,
const int axis, T* output_data) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int channel_idx, feat_idx, data_idx;
if (axis == 1) {
channel_idx = batch_offset / feat_len;
feat_idx = batch_offset % feat_len;
data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
} else {
channel_idx = batch_offset % channels;
feat_idx = batch_offset / channels;
data_idx =
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
}
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
int idx_offset = (axis == 1 ? g * feat_len : g);
T x = input_data[data_idx + idx_offset];
ele = ele > x ? ele : x;
}
output_data[i] = ele;
......@@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
const T* output_data, const T* output_grad,
T* input_grad, const int channels,
const int input_height, const int input_width,
int groups) {
const int groups, const int axis) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int channel_idx, feat_idx, data_idx;
if (axis == 1) {
channel_idx = batch_offset / feat_len;
feat_idx = batch_offset % feat_len;
data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
} else {
channel_idx = batch_offset % channels;
feat_idx = batch_offset / channels;
data_idx =
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
}
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
int idx_offset = (axis == 1 ? g * feat_len : g);
if (input_data[data_idx + idx_offset] == output_data[i]) {
max_index = data_idx + idx_offset;
continue_match = false;
break;
}
......@@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
}
}
/*
* All tensors are in NCHW format.
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
......@@ -100,11 +116,11 @@ class MaxOutFunctor<platform::CUDADeviceContext, T> {
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, groups,
output_data);
axis, output_data);
}
};
/*
* All tensors are in NCHW format.
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
......@@ -112,14 +128,13 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups) {
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
......@@ -132,7 +147,7 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
input_channels, input_height, input_width, groups, axis);
}
};
......
......@@ -26,7 +26,8 @@ template <typename DeviceContext, typename T>
class MaxOutFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* output, int groups);
framework::Tensor* output, const int groups,
const int axis = 1);
};
template <typename DeviceContext, class T>
......@@ -35,7 +36,8 @@ class MaxOutGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
const framework::Tensor& output_grad, const int groups,
const int axis = 1);
};
} // namespace math
} // namespace operators
......
......@@ -23,25 +23,27 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of maxout operator with data type of "
"float32. The format of input tensor is NCHW. Where N is batch size,"
" C is the number of channels, H and W is the height and width of "
"feature.");
AddInput("X",
"A 4-D Tensor with data type of float32 or float64. "
"The data format is NCHW or NHWC. Where N is "
"batch size, C is the number of channels, "
"H and W is the height and width of "
"feature. ");
AddOutput("Out",
"(Tensor) The output tensor of maxout operator."
"The data type is float32."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
"A 4-D Tensor with same data type and data format "
"with input Tensor. ");
AddAttr<int>(
"groups",
"(int),"
"Specifies how many groups the input tensor will be split"
"in the channel dimension. And the number of output channel is "
"the number of channels divided by groups.");
"Specifies how many groups the input tensor will be split into "
"at the channel dimension. And the number of output channel is "
"the number of channels divided by groups. ");
AddAttr<int>(
"axis",
"Specifies the index of channel dimension where maxout will "
"be performed. It should be 1 when data format is NCHW, -1 or 3 "
"when data format is NHWC. "
"Default: 1. ")
.SetDefault(1);
AddComment(R"DOC(
MaxOut Operator.
......@@ -70,17 +72,19 @@ class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MaxoutOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
int axis = ctx->Attrs().Get<int>("axis");
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
output_shape.push_back(in_x_dims[2]);
output_shape.push_back(in_x_dims[3]);
PADDLE_ENFORCE_GT(groups, 1,
"Attr(groups) of Op(maxout) should be larger than 1.");
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
......
......@@ -30,10 +30,11 @@ class MaxOutKernel : public framework::OpKernel<T> {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
math::MaxOutFunctor<DeviceContext, T> maxout_forward;
maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
groups);
groups, axis);
}
};
......@@ -47,13 +48,15 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
math::MaxOutGradFunctor<DeviceContext, T> maxout_backward;
maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups);
maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups,
axis);
}
}
};
......
......@@ -9334,7 +9334,8 @@ def lod_append(x, level):
return out
def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None,
data_format='NCHW'):
"""
This operator implements the Local Response Normalization Layer.
This layer performs a type of "lateral inhibition" by normalizing over local input regions.
......@@ -9355,13 +9356,18 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
Args:
input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is Height, W is weight. The data type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError.
input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W] or [N, H, W, C],
where N is the batch size, C is the input channel, H is the height, W is the width. The data
type is float32. The rank of this tensor must be 4, otherwise it will raise ValueError.
n (int, optional): The number of channels to sum over. Default: 5
k (float, optional): An offset, positive. Default: 1.0
alpha (float, optional): The scaling parameter, positive. Default:1e-4
beta (float, optional): The exponent, positive. Default:0.75
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
name (str, optional): The default value is None. Normally there is no need for user to set
this property. For more information, please refer to :ref:`api_guide_Name`
data_format(str, optional): The data format of the input and output data. An optional string
from: `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. Default: 'NCHW'.
Returns:
Variable: A tensor variable storing the transformation result with the same shape and data type as input.
......@@ -9384,8 +9390,12 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
if dims != 4:
raise ValueError(
"dims of input must be 4(not %d), and it's order must be NCHW" %
"Input's dimension size of Op(lrn) must be 4, but received %d." %
(dims))
if data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Attr(data_format) of Op(lrn) got wrong value: received " +
data_format + " but only NCHW or NHWC supported.")
mid_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
......@@ -9397,10 +9407,13 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
"Out": lrn_out,
"MidOut": mid_out,
},
attrs={"n": n,
"k": k,
"alpha": alpha,
"beta": beta})
attrs={
"n": n,
"k": k,
"alpha": alpha,
"beta": beta,
"data_format": data_format
})
return lrn_out
......@@ -11547,7 +11560,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
* Case 1 (input is a 2-D Tensor):
Input:
X.shape = [3. 5]
X.shape = [3, 5]
X.data = [[0, 1, 2, 0, 0],
[0, 3, 4, 0, 0],
[0, 0, 0, 0, 0]]
......@@ -11555,8 +11568,9 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
shape = [2, 2]
offsets = [0, 1]
Output:
Out = [[1, 2],
[3, 4]]
Out.shape = [2, 2]
Out.data = [[1, 2],
[3, 4]]
* Case 2 (input is a 3-D Tensor):
Input:
X.shape = [2, 3, 4]
......@@ -11567,24 +11581,23 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
[0, 6, 7, 8],
[0, 0, 0, 0]]]
Parameters:
shape = [2, 2, 3]
shape = [2, 2, -1]
offsets = [0, 0, 1]
Output:
Out = [[[1, 2, 3],
[5, 6, 7]],
[[3, 4, 5],
[6, 7, 8]]]
Out.shape = [2, 2, 3]
Out.data = [[[1, 2, 3],
[5, 6, 7]],
[[3, 4, 5],
[6, 7, 8]]]
Parameters:
x (Variable): 1-D to 6-D Tensor, the data type is float32 or float64.
x (Variable): 1-D to 6-D Tensor, the data type is float32, float64, int32 or int64.
shape (list|tuple|Variable): The output shape is specified
by `shape`. Its data type is int32. If a list/tuple, its length must be
the same as the dimension size of `x`. If a Variable, it should be a 1-D Tensor.
When it is a list, each element can be an integer or a Tensor of shape: [1].
If Variable contained, it is suitable for the case that the shape may
be changed each iteration. Only the first element of list/tuple can be
set to -1, it means that the first dimension's size of the output is the same
as the input.
If Variable contained, it is suitable for the case that the shape may
be changed each iteration.
offsets (list|tuple|Variable, optional): Specifies the cropping
offsets at each dimension. Its data type is int32. If a list/tuple, its length
must be the same as the dimension size of `x`. If a Variable, it should be a 1-D
......@@ -11598,8 +11611,12 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
Variable: The cropped Tensor has the same data type as `x`.
Raises:
ValueError: If shape is not a list, tuple or Variable.
ValueError: If offsets is not None and not a list, tuple or Variable.
TypeError: If the data type of `x` is not in: float32, float64, int32, int64.
TypeError: If `shape` is not a list, tuple or Variable.
TypeError: If the data type of `shape` is not int32.
TypeError: If `offsets` is not None and not a list, tuple or Variable.
TypeError: If the data type of `offsets` is not int32.
ValueError: If the element in `offsets` is less than zero.
Examples:
......@@ -11615,7 +11632,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
# crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime.
# or shape is a list in which each element is a constant
crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3])
crop1 = fluid.layers.crop_tensor(x, shape=[-1, -1, 3], offsets=[0, 1, 0])
# crop1.shape = [-1, 2, 3]
# or shape is a list in which each element is a constant or Variable
......@@ -11637,70 +11654,98 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
"""
helper = LayerHelper('crop_tensor', **locals())
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
raise TypeError(
"Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
raise TypeError(
"Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable."
)
if offsets is None:
offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.")
raise TypeError(
"Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable."
)
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x}
attrs = {}
def contain_var(input_list):
def _contain_var(input_list):
for ele in input_list:
if isinstance(ele, Variable):
return True
return False
def _attr_shape_check(shape_val):
if not isinstance(shape_val, int):
raise TypeError(
"Attr(shape)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(shape_val))
if shape_val == 0:
raise ValueError(
"Attr(shape) of Op(crop_tensor) should not be zero, but received: %s."
% str(shape_val))
if shape_val < -1:
raise ValueError(
"When the element in Attr(shape) of Op(crop_tensor) is negative, only -1 is supported, but received: %s."
% str(shape_val))
def _attr_offsets_check(offset_val):
if not isinstance(offset_val, int):
raise TypeError(
"Attr(offsets)'s dtype of Op(crop_tensor) should be int32, but received: %s."
% type(offset_val))
if offset_val < 0:
raise ValueError(
"Attr(offsets) of Op(crop_tensor) should be greater or equal to zero, but received: %s."
% str(offset_val))
if isinstance(offsets, Variable):
offsets.stop_gradient = True
ipts['Offsets'] = offsets
elif contain_var(offsets):
attrs['offsets'] = [-1] * len(x.shape)
elif _contain_var(offsets):
new_offsets_tensor = []
offsets_attr = []
for dim in offsets:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_offsets_tensor.append(dim)
offsets_attr.append(-1)
else:
assert (isinstance(dim, int))
assert dim >= 0, ("offsets should be greater or equal to zero.")
_attr_offsets_check(dim)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out)
offsets_attr.append(dim)
ipts['OffsetsTensor'] = new_offsets_tensor
attrs['offsets'] = offsets_attr
else:
for offset in offsets:
_attr_offsets_check(offset)
attrs['offsets'] = offsets
unk_dim_idx = -1
if isinstance(shape, Variable):
shape.stop_gradient = True
ipts['Shape'] = shape
elif contain_var(shape):
elif _contain_var(shape):
new_shape_tensor = []
shape_attr = []
for dim_idx, dim_size in enumerate(shape):
for dim_size in shape:
if isinstance(dim_size, Variable):
dim_size.stop_gradient = True
new_shape_tensor.append(dim_size)
shape_attr.append(-1)
shape_attr.append(0)
else:
assert (isinstance(dim_size, int))
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
_attr_shape_check(dim_size)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out)
......@@ -11709,6 +11754,8 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr
else:
for dim_size in shape:
_attr_shape_check(dim_size)
attrs['shape'] = shape
helper.append_op(
......@@ -15195,22 +15242,23 @@ def sigmoid_cross_entropy_with_logits(x,
@templatedoc()
def maxout(x, groups, name=None):
def maxout(x, groups, name=None, axis=1):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
groups(${groups_type}): ${groups_comment}
groups(int): ${groups_comment}
axis(int, optional): ${axis_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Variable:
out(${out_type}): ${out_comment}
Variable: ${out_comment}
Raises:
ValueError: If `axis` is not 1, -1 or 3.
Examples:
.. code-block:: python
......@@ -15223,6 +15271,12 @@ def maxout(x, groups, name=None):
out = fluid.layers.maxout(input, groups=2)
"""
helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -15233,7 +15287,8 @@ def maxout(x, groups, name=None):
helper.append_op(
type="maxout",
inputs={"X": x},
attrs={"groups": groups},
attrs={"groups": groups,
"axis": axis},
outputs={"Out": out})
return out
......
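The Python-side checks added above turn invalid attributes into early TypeError/ValueError exceptions; a short sketch of the expected behaviour (mirrors the checks in this diff, shapes are illustrative):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[2, 3, 6, 6], dtype='float32')

# Attr(shape) entries must be int, non-zero, and no negative value other than -1.
# fluid.layers.crop_tensor(x, shape=[2, 2.0, 3, 3])   # TypeError
# fluid.layers.crop_tensor(x, shape=[2, 0, 3, 3])     # ValueError
# fluid.layers.crop_tensor(x, shape=[2, -2, 3, 3])    # ValueError

# Attr(axis) of maxout must be 1, -1 or 3; -1 is normalized to 3.
# fluid.layers.maxout(x, groups=3, axis=2)            # ValueError
out = fluid.layers.maxout(x, groups=3, axis=1)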
......@@ -44,13 +44,13 @@ def crop(data, offsets, crop_shape):
class TestCropTensorOp(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.crop_by_1D_shape = False
self.shape_by_input = False
self.offset_by_input = False
self.unk_dim_idx = -1
self.attrs = {}
self.initTestCase()
if self.crop_by_1D_shape:
if self.shape_by_input:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Shape': np.array(self.crop_shape).astype("int32")
......@@ -65,11 +65,11 @@ class TestCropTensorOp(OpTest):
else:
self.attrs['offsets'] = self.offsets
if self.unk_dim_idx != -1:
self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx]
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
crop_shape = [val for val in self.crop_shape]
for i in range(len(self.crop_shape)):
if self.crop_shape[i] == -1:
crop_shape[i] = self.x_shape[i] - self.offsets[i]
self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self):
self.x_shape = (8, 8)
......@@ -93,9 +93,8 @@ class TestCase1(TestCropTensorOp):
class TestCase2(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (12, 24)
self.crop_shape = [-1, 8] #only the first dimension (batch) can be -1
self.crop_shape = [-1, 8]
self.offsets = [0, 0]
self.unk_dim_idx = 0
class TestCase3(TestCropTensorOp):
......@@ -103,16 +102,15 @@ class TestCase3(TestCropTensorOp):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_1D_shape = True
self.shape_by_input = True
class TestCase4(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (8, 3, 6, 6)
self.crop_shape = [-1, 3, 4, 4]
self.offsets = [0, 0, 0, 0]
self.crop_by_1D_shape = True
self.unk_dim_idx = 0
self.crop_shape = [-1, 3, -1, 4]
self.offsets = [0, 0, 1, 0]
self.shape_by_input = True
class TestCase5(TestCropTensorOp):
......@@ -128,14 +126,13 @@ class TestCase6(TestCropTensorOp):
self.x_shape = (2, 2, 4, 4, 4, 2)
self.crop_shape = [1, 1, 4, 2, 2, 2]
self.offsets = [0, 0, 0, 0, 0, 0]
self.crop_by_1D_shape = True
self.shape_by_input = True
self.offset_by_input = True
class TestCropTensorOp_attr_tensor(OpTest):
class TestCropTensorOpTensorAttr(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.mixed_type = False
self.OffsetsTensor = False
self.ShapeTensor = True
self.attrs = {}
......@@ -150,8 +147,7 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"),
'ShapeTensor': shape_tensor
}
if self.mixed_type:
self.attrs['shape'] = self.shape_attr
self.attrs['shape'] = self.shape_attr
if self.OffsetsTensor:
offsets_tensor = []
......@@ -162,17 +158,21 @@ class TestCropTensorOp_attr_tensor(OpTest):
'X': np.random.random(self.x_shape).astype("float32"),
'OffsetsTensor': offsets_tensor
}
else:
self.attrs['offsets'] = self.offsets
self.attrs['offsets'] = self.offsets_attr
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
self.attrs['shape'] = self.crop_shape
self.attrs['offsets'] = self.offsets
crop_shape = [val for val in self.crop_shape]
for i in range(len(self.crop_shape)):
if self.crop_shape[i] == -1:
crop_shape[i] = self.x_shape[i] - self.offsets[i]
self.outputs = {'Out': crop(self.inputs['X'], self.offsets, crop_shape)}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
self.shape_attr = [0, 0]
def test_check_output(self):
self.check_output()
......@@ -181,38 +181,85 @@ class TestCropTensorOp_attr_tensor(OpTest):
self.check_grad(["X"], "Out", max_relative_error=0.006)
class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.crop_shape = [-1, -1, 3]
self.offsets = [1, 5, 3]
self.shape_attr = [-1, -1, 3]
class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase2(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (4, 8, 16, 8)
self.crop_shape = [2, 2, 3, 4]
self.offsets = [1, 5, 3, 0]
self.shape_attr = [-1, -1, 3, 4]
self.mixed_type = True
self.shape_attr = [0, 0, 3, 4]
class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase3(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.ShapeTensor = False
self.OffsetsTensor = True
class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor):
def init_data(self):
class TestCropTensorOpTensorAttrCase4(TestCropTensorOpTensorAttr):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.shape_attr = [0, 2, 3]
self.offsets = [1, 5, 3]
self.offsets_attr = [-1, -1, 3]
self.OffsetsTensor = True
class TestCropTensorException(OpTest):
def test_exception(self):
input1 = fluid.data(name="input1", shape=[2, 3, 6, 6], dtype="float32")
input2 = fluid.data(name="input2", shape=[2, 3, 6, 6], dtype="float16")
dim = fluid.data(name='dim', shape=[1], dtype='int32')
offset = fluid.data(name='offset', shape=[1], dtype='int32')
def attr_shape_type():
out = fluid.layers.crop_tensor(input1, shape=3)
def attr_shape_dtype():
out = fluid.layers.crop_tensor(input1, shape=[2, 2.0, 3, 3])
def attr_shape_value1():
out = fluid.layers.crop_tensor(input1, shape=[2, -2, dim, 3])
def attr_shape_value2():
out = fluid.layers.crop_tensor(input1, shape=[2, 0, dim, 3])
def attr_offsets_type():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=0)
def attr_offsets_dtype():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, 1.0, 0, 0])
def attr_offsets_value():
out = fluid.layers.crop_tensor(
input1, shape=[2, 2, 3, 3], offsets=[0, -1, offset, 0])
def input_dtype():
out = fluid.layers.crop_tensor(input2, shape=[2, 2, 3, 3])
self.assertRaises(TypeError, attr_shape_type)
self.assertRaises(TypeError, attr_shape_dtype)
self.assertRaises(ValueError, attr_shape_value1)
self.assertRaises(ValueError, attr_shape_value2)
self.assertRaises(TypeError, attr_offsets_type)
self.assertRaises(TypeError, attr_offsets_dtype)
self.assertRaises(ValueError, attr_offsets_value)
self.assertRaises(TypeError, input_dtype)
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
......@@ -60,12 +62,15 @@ class TestLRNOp(OpTest):
'n': self.n,
'k': self.k,
'alpha': self.alpha,
'beta': self.beta
'beta': self.beta,
'data_format': self.data_format
}
return attrs
def setUp(self):
self.op_type = "lrn"
self.init_test_case()
self.N = 2
self.C = 3
self.H = 5
......@@ -77,11 +82,18 @@ class TestLRNOp(OpTest):
self.beta = 0.75
self.x = self.get_input()
self.out, self.mid_out = self.get_out()
if self.data_format == 'NHWC':
self.x = np.transpose(self.x, [0, 2, 3, 1])
self.out = np.transpose(self.out, [0, 2, 3, 1])
self.mid_out = np.transpose(self.mid_out, [0, 2, 3, 1])
self.inputs = {'X': self.x}
self.outputs = {'Out': self.out, 'MidOut': self.mid_out}
self.attrs = self.get_attrs()
def init_test_case(self):
self.data_format = 'NCHW'
def test_check_output(self):
self.check_output()
......@@ -89,5 +101,49 @@ class TestLRNOp(OpTest):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
class TestLRNOpAttrDataFormat(TestLRNOp):
def init_test_case(self):
self.data_format = 'NHWC'
class TestLRNAPI(OpTest):
def test_case(self):
data1 = fluid.data(name='data1', shape=[2, 4, 5, 5], dtype='float32')
data2 = fluid.data(name='data2', shape=[2, 5, 5, 4], dtype='float32')
out1 = fluid.layers.lrn(data1, data_format='NCHW')
out2 = fluid.layers.lrn(data2, data_format='NHWC')
data1_np = np.random.random((2, 4, 5, 5)).astype("float32")
data2_np = np.transpose(data1_np, [0, 2, 3, 1])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2],
return_numpy=True)
self.assertTrue(
np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
def test_exception(self):
input1 = fluid.data(name="input1", shape=[2, 4, 5, 5], dtype="float32")
input2 = fluid.data(
name="input2", shape=[2, 4, 5, 5, 5], dtype="float32")
def _attr_data_format():
out = fluid.layers.lrn(input1, data_format='NDHW')
def _input_dim_size():
out = fluid.layers.lrn(input2)
self.assertRaises(ValueError, _attr_data_format)
self.assertRaises(ValueError, _input_dim_size)
if __name__ == "__main__":
unittest.main()
......@@ -16,11 +16,16 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
def maxout_forward_naive(input, groups):
def maxout_forward_naive(input, groups, channel_axis):
s0, s1, s2, s3 = input.shape
if channel_axis == 3:
return np.ndarray([s0, s1, s2, s3 // groups, groups], \
buffer = input, dtype=input.dtype).max(axis=(4))
return np.ndarray([s0, s1 // groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
......@@ -30,10 +35,11 @@ class TestMaxOutOp(OpTest):
self.op_type = "maxout"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
output = self.MaxOut_forward_naive(input, self.groups).astype("float32")
output = self.MaxOut_forward_naive(input, self.groups,
self.axis).astype("float32")
self.inputs = {'X': input}
self.attrs = {'groups': self.groups}
self.attrs = {'groups': self.groups, 'axis': self.axis}
self.outputs = {'Out': output.astype('float32')}
......@@ -47,6 +53,48 @@ class TestMaxOutOp(OpTest):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups = 2
self.axis = 1
class TestMaxOutOpAxis(TestMaxOutOp):
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 2, 2, 6] # NHWC format
self.groups = 2
self.axis = 3
class TestMaxOutOpAxisAPI(OpTest):
def test_axis(self):
data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
out1 = fluid.layers.maxout(data1, groups=2, axis=1)
out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
data2_np = np.transpose(data1_np, [0, 2, 3, 1])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2],
return_numpy=True)
self.assertTrue(
np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
def test_exception(self):
input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")
def _attr_axis():
out = fluid.layers.maxout(input, groups=2, axis=2)
self.assertRaises(ValueError, _attr_axis)
if __name__ == '__main__':
......