You need to sign in or sign up before continuing.
未验证 提交 f86fead6 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add trilinear_interp OP (#18711)

* add trilinear interp. test=develop

* fix unittest. test=develop

* add python api and test_layers. test=develop

* refine API.spec. test=develop

* fix format. test=develop

* add python API test. test=develop

* format code. test=develop

* refine code strcuture. test=develop

* fix format

* fix doc. test=develop

* fix converage. test=develop

* fix format. test=develop
上级 c2063217
...@@ -173,10 +173,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon' ...@@ -173,10 +173,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5')) paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591')) paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5')) paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', 'a29488d94d9a4bc4434d8a3529b4c6fe')) paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '8cfc4f69dbbedb687b6c20732aa8f09e'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447')) paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '548c7c2ead5771d15abbaad505f901e9')) paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '832b2412652d84a6631b1012c6e2d18b'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', 'b7d810d1e251c5957c1efa6aa699d2d0')) paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c')) paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab')) paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf')) paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))
......
...@@ -20,6 +20,85 @@ namespace operators { ...@@ -20,6 +20,85 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE("trilinear" == interp_method,
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5");
int out_d, out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_d = static_cast<int>(dim_x[2] * scale);
out_h = static_cast<int>(dim_x[3] * scale);
out_w = static_cast<int>(dim_x[4] * scale);
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_d = ctx->Attrs().Get<int>("out_d");
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0.");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3");
ctx->ShareLoD("X", "Out");
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
class InterpolateOp : public framework::OperatorWithKernel { class InterpolateOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -31,41 +110,17 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -31,41 +110,17 @@ class InterpolateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of InterpolationOp should not be null."); "Output(Out) of InterpolationOp should not be null.");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format auto dim_x = ctx->GetInputDim("X"); // NCHW format
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
"Input(X) dimension must be 4 or 5");
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale"); if (dim_x.size() == 4) {
if (scale > 0) { // shape check for 2D interpolate for input tensor shape NCHW
// round down Interpolate2DInferShapeCheck(ctx);
out_h = static_cast<int>(dim_x[2] * scale); } else { // dim_x.size() == 5
out_w = static_cast<int>(dim_x[3] * scale); // shape check for 3D interpolate for input tensor shape NCDHW
// protect when input shape is -1 Interpolate3DInferShapeCheck(ctx);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
} }
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} }
protected: protected:
...@@ -81,22 +136,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,22 +136,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"The input tensor of interpolate operator, " "The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, w]."); "This is a 4-D tensor with shape of [N, C, H, W] or a "
"5-D tensor with shape of [N, C, D, H, W].");
AddInput("OutSize", AddInput("OutSize",
"This is a 1-D tensor with two numbers to specify output size. " "This is a 1-D tensor with two numbers to specify output size. "
"The first number is height and the second number is width.") "It should be [output_height, output_width] when input is a 4-D "
"tensor and should be [output_depth, output_height, output_width] "
"when input is a 5-D tensor.")
.AsDispensable(); .AsDispensable();
AddOutput("Out", AddOutput("Out",
"The output tensor of interpolate operator, " "The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W]."); "This is a tensor in same rank with Input(X).");
AddAttr<int>("out_h", "output height of interpolate op."); AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op."); AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.); AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
AddAttr<std::string>("interp_method", AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation " "(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for " "method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest " "bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation.") "neighbor interpolation.")
.SetDefault("bilinear"); .SetDefault("bilinear");
AddAttr<bool>( AddAttr<bool>(
...@@ -127,6 +187,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -127,6 +187,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
to perform linear interpolation first in one direction, and then to perform linear interpolation first in one direction, and then
again in the other direction. again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optinal parameters,the calculation method Align_corners and align_mode are optinal parameters,the calculation method
of interpolation can be selected by them. of interpolation can be selected by them.
...@@ -183,6 +248,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -183,6 +248,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H_out = H_{in} * scale_{factor} H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor} W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia: For details of nearest neighbor interpolation, please refer to Wikipedia:
...@@ -190,6 +276,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -190,6 +276,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
For details of bilinear interpolation, please refer to Wikipedia: For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation https://en.wikipedia.org/wiki/Bilinear_interpolation
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
)DOC"); )DOC");
} }
}; };
...@@ -251,6 +340,10 @@ REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker, ...@@ -251,6 +340,10 @@ REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker); ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad, REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference); ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>, REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>, ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>); ops::InterpolateKernel<uint8_t>);
...@@ -261,3 +354,8 @@ REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>, ...@@ -261,3 +354,8 @@ REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>); ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>, REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>); ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
...@@ -191,80 +191,483 @@ __global__ void KeBilinearInterpBw( ...@@ -191,80 +191,483 @@ __global__ void KeBilinearInterpBw(
} }
template <typename T> template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> { __global__ void KeTrilinearInterpFw(
public: const T* in, const size_t in_img_d, const size_t in_img_h,
void Compute(const framework::ExecutionContext& ctx) const override { const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
"This kernel only runs on GPU device."); const size_t output_h, const size_t output_w, const size_t num_channels,
auto* input = ctx.Input<Tensor>("X"); const float ratio_d, const float ratio_h, const float ratio_w,
auto* output = ctx.Output<Tensor>("Out"); const bool align_corners, const int align_mode) {
auto* input_data = input->data<T>(); int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int n = input->dims()[0]; int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
int c = input->dims()[1]; int in_img_idt = align_flag
int in_h = input->dims()[2]; ? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
int in_w = input->dims()[3]; : static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
auto interp_method = ctx.Attr<std::string>("interp_method"); int out_img_idx = tid % out_img_w;
int out_h = ctx.Attr<int>("out_h"); int in_img_idx = align_flag
int out_w = ctx.Attr<int>("out_w"); ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
float scale = ctx.Attr<float>("scale"); int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
if (scale > 0) { (in_img_idt * in_img_h + in_img_idy) * in_img_w +
out_h = in_h * scale; in_img_idx;
out_w = in_w * scale; const T* in_pos1 = &in[in_pos1_idx];
} int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
const T* in_pos2 = &in[in_pos2_idx];
auto out_size = ctx.Input<Tensor>("OutSize"); // trilinear interpolation
if (out_size != nullptr) { out[out_id_h * output_w + out_id_w] =
Tensor sizes; d2lambda *
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); (h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
auto size_data = sizes.data<int>(); h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
out_h = size_data[0]; w1lambda * in_pos1[h_id * in_img_w + w_id])) +
out_w = size_data[1]; d1lambda *
} (h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
w1lambda * in_pos2[h_id * in_img_w + w_id]));
}
}
bool align_corners = ctx.Attr<bool>("align_corners"); template <typename T>
int align_mode = ctx.Attr<int>("align_mode"); __global__ void KeTrilinearInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners,
const int align_mode) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
auto* output_data = int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace()); int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
int in_hw = in_h * in_w; int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
int out_hw = out_h * out_w; (in_img_idt * in_img_h + in_img_idy) * in_img_w +
int in_chw = c * in_hw; in_img_idx;
int out_chw = c * out_hw; T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
T* in_pos2 = &in[in_pos2_idx];
float ratio_h = 0.f; const T* out_pos = &out[out_id_h * output_w + out_id_w];
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (in_h == out_h && in_w == out_w) { // trilinear interpolation grad
framework::TensorCopy(*input, ctx.GetPlace(), output); platform::CudaAtomicAdd(&in_pos1[0],
return; d2lambda * h2lambda * w2lambda * out_pos[0]);
} platform::CudaAtomicAdd(&in_pos1[w_id],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
}
int pixelNum = n * out_chw; template <typename T>
int grid_dim = (pixelNum + 512 - 1) / 512; static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
grid_dim = grid_dim > 8 ? 8 : grid_dim; const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw< const int n = input.dims()[0];
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( const int c = input.dims()[1];
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, const int in_h = input.dims()[2];
out_chw, c, ratio_h, ratio_w, align_corners); const int in_w = input.dims()[3];
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw< auto interp_method = ctx.Attr<std::string>("interp_method");
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( bool align_corners = ctx.Attr<bool>("align_corners");
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, int align_mode = ctx.Attr<int>("align_mode");
out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
}
}
template <typename T>
static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_d = input.dims()[2];
const int in_h = input.dims()[3];
const int in_w = input.dims()[4];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
auto output_data =
output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_dhw = in_d * in_h * in_w;
int out_dhw = out_d * out_h * out_w;
int in_cdhw = c * in_dhw;
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode);
}
}
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto* output_grad_data = output_grad.data<T>();
auto* input_grad_data =
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
}
}
template <typename T>
static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad,
const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_d = input->dims()[2];
const int in_h = input->dims()[3];
const int in_w = input->dims()[4];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
auto* output_grad_data = output_grad.data<T>();
auto* input_grad_data =
input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_dhw = in_d * in_h * in_w;
int out_dhw = out_d * out_h * out_w;
int in_cdhw = c * in_dhw;
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode);
}
}
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDAFwd<T>(ctx, *input, output);
} }
} }
}; };
...@@ -273,76 +676,16 @@ template <typename T> ...@@ -273,76 +676,16 @@ template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
int n = input_grad->dims()[0];
int c = input_grad->dims()[1];
int in_h = input_grad->dims()[2];
int in_w = input_grad->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int in_hw = in_h * in_w; auto output_grad_dims = output_grad->dims();
int out_hw = out_h * out_w; if (output_grad_dims.size() == 4) { // 2D interpolation
int in_chw = c * in_hw; Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
int out_chw = c * out_hw; } else if (output_grad_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
} }
} }
}; };
...@@ -363,3 +706,9 @@ REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>, ...@@ -363,3 +706,9 @@ REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(nearest_interp_grad, REGISTER_OP_CUDA_KERNEL(nearest_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>, ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>); ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
...@@ -131,6 +131,128 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, ...@@ -131,6 +131,128 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
} }
} }
template <typename T>
static void TrilinearInterpolation(
const Tensor& input, Tensor* output, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const bool align_mode) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vt_f, vt_b;
std::vector<float> vd_f, vd_b;
vt_f.reserve(out_d);
vt_b.reserve(out_d);
vd_f.reserve(out_d);
vd_b.reserve(out_d);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int j = 0; j < out_d; j++) {
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
{
vt_f[j] = t_f;
vt_b[j] = t_b;
vd_f[j] = d_f;
vd_b[j] = d_b;
}
}
std::vector<int> vy_n, vy_s;
std::vector<float> vd_n, vd_s;
vy_n.reserve(out_h);
vy_s.reserve(out_h);
vd_n.reserve(out_h);
vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int k = 0; k < out_h; k++) {
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
{
vy_n[k] = y_n;
vy_s[k] = y_s;
vd_n[k] = d_n;
vd_s[k] = d_s;
}
}
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int l = 0; l < out_w; l++) {
int x_w = (align_mode == 0 && !align_corners)
? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
{
vx_w[l] = x_w;
vx_e[l] = x_e;
vd_w[l] = d_w;
vd_e[l] = d_e;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
for (int j = 0; j < out_d; j++) { // loop for D, H, W
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
// trilinear interpolation
T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
vd_s[k] * vd_e[l] +
input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
vd_s[k] * vd_w[l] +
input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
vd_n[k] * vd_e[l] +
input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
vd_n[k] * vd_w[l] +
input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
vd_s[k] * vd_e[l] +
input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
vd_s[k] * vd_w[l] +
input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
vd_n[k] * vd_e[l] +
input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
vd_n[k] * vd_w[l];
output_t(b, i, j, k, l) = out_t;
}
}
}
}
}
}
template <typename T> template <typename T>
static void NearestNeighborInterpolateGrad( static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h, const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
...@@ -200,134 +322,340 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, ...@@ -200,134 +322,340 @@ static void BilinearInterpolationGrad(const Tensor& output_grad,
} }
} }
} }
template <typename T> template <typename T>
class InterpolateKernel : public framework::OpKernel<T> { static void TrilinearInterpolationGrad(
public: const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
void Compute(const framework::ExecutionContext& ctx) const override { const float ratio_h, const float ratio_w, const int in_d, const int in_h,
auto* input = ctx.Input<Tensor>("X"); const int in_w, const int n, const int c, const int out_d, const int out_h,
auto* output = ctx.Output<Tensor>("Out"); const int out_w, const bool align_corners, const int align_mode) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int j = 0; j < out_d; j++) { // loop for D
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
for (int k = 0; k < out_h; k++) { // loop for H
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) { // loop for W
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
// trilinear interpolation grad
const T grad = output_grad_t(b, i, j, k, l);
input_grad_t(b, i, t_f, y_n, x_w) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, i, t_f, y_n, x_e) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, i, t_f, y_s, x_w) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, i, t_f, y_s, x_e) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, i, t_b, y_n, x_w) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, i, t_b, y_n, x_e) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, i, t_b, y_s, x_w) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, i, t_b, y_s, x_e) +=
static_cast<T>(grad * d_f * d_n * d_w);
}
}
}
}
}
}
const int n = input->dims()[0]; template <typename T>
const int c = input->dims()[1]; static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const int in_h = input->dims()[2]; const Tensor& input, Tensor* output) {
const int in_w = input->dims()[3]; const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
std::string interp_method = ctx.Attr<std::string>("interp_method"); auto out_size = ctx.Input<Tensor>("OutSize");
int out_h = ctx.Attr<int>("out_h"); if (out_size != nullptr) {
int out_w = ctx.Attr<int>("out_w"); auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
float scale = ctx.Attr<float>("scale"); output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize"); if (in_h == out_h && in_w == out_w) {
if (out_size != nullptr) { framework::TensorCopy(input, ctx.GetPlace(), output);
auto out_size_data = out_size->data<int>(); return;
out_h = out_size_data[0]; }
out_w = out_size_data[1];
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, output, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f; float ratio_h = 0.f;
float ratio_w = 0.f; float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (out_h > 1) { if ("bilinear" == interp_method) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1) BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
: static_cast<float>(in_h) / out_h; out_h, out_w, align_corners, align_mode);
} } else if ("nearest" == interp_method) {
if (out_w > 1) { NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1) out_w, align_corners);
: static_cast<float>(in_w) / out_w; }
} }
if ("bilinear" == interp_method) { template <typename T>
BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n, static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
c, out_h, out_w, align_corners, align_mode); const Tensor& input, Tensor* output) {
} else if ("nearest" == interp_method) { const int n = input.dims()[0];
NearestNeighborInterpolate<T>(*input, output, ratio_h, ratio_w, n, c, const int c = input.dims()[1];
out_h, out_w, align_corners); const int in_d = input.dims()[2];
} const int in_h = input.dims()[3];
const int in_w = input.dims()[4];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
} }
};
if ("trilinear" == interp_method) {
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
in_h, in_w, n, c, out_d, out_h, out_w,
align_corners, align_mode);
}
}
template <typename T> template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> { static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
public: Tensor* input_grad, const Tensor& output_grad) {
void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input<Tensor>("X");
auto* input = ctx.Input<Tensor>("X"); const int n = input->dims()[0];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); const int c = input->dims()[1];
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
const int n = input->dims()[0]; auto out_size = ctx.Input<Tensor>("OutSize");
const int c = input->dims()[1]; if (out_size != nullptr) {
const int in_h = input->dims()[2]; auto out_size_data = out_size->data<int>();
const int in_w = input->dims()[3]; out_h = out_size_data[0];
out_w = out_size_data[1];
}
std::string interp_method = ctx.Attr<std::string>("interp_method"); input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
int out_h = ctx.Attr<int>("out_h"); auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
int out_w = ctx.Attr<int>("out_w"); math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
float scale = ctx.Attr<float>("scale"); if (in_h == out_h && in_w == out_w) {
if (scale > 0) { framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
out_h = static_cast<int>(in_h * scale); return;
out_w = static_cast<int>(in_w * scale); }
}
auto out_size = ctx.Input<Tensor>("OutSize"); float ratio_h = 0.f;
if (out_size != nullptr) { float ratio_w = 0.f;
auto out_size_data = out_size->data<int>(); if (out_h > 1) {
out_h = out_size_data[0]; ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
out_w = out_size_data[1]; : static_cast<float>(in_h) / out_h;
} }
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w, align_corners,
align_mode);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners);
}
}
bool align_corners = ctx.Attr<bool>("align_corners"); template <typename T>
int align_mode = ctx.Attr<int>("align_mode"); static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_d = input->dims()[2];
const int in_h = input->dims()[3];
const int in_w = input->dims()[4];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace()); auto out_size = ctx.Input<Tensor>("OutSize");
auto& device_ctx = if (out_size != nullptr) {
ctx.template device_context<platform::CPUDeviceContext>(); auto out_size_data = out_size->data<int>();
math::SetConstant<platform::CPUDeviceContext, T> zero; out_d = out_size_data[0];
zero(device_ctx, input_grad, static_cast<T>(0.0)); out_h = out_size_data[1];
out_w = out_size_data[2];
}
if (in_h == out_h && in_w == out_w) { input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
return; math::SetConstant<platform::CPUDeviceContext, T> zero;
} zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f; float ratio_d = 0.f;
float ratio_w = 0.f; float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (out_h > 1) { if ("trilinear" == interp_method) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1) TrilinearInterpolationGrad<T>(output_grad, input_grad, ratio_d, ratio_h,
: static_cast<float>(in_h) / out_h; ratio_w, in_d, in_h, in_w, n, c, out_d, out_h,
} out_w, align_corners, align_mode);
if (out_w > 1) { }
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1) }
: static_cast<float>(in_w) / out_w;
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCPUFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCPUFwd<T>(ctx, *input, output);
} }
}
};
template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
if ("bilinear" == interp_method) { auto output_grad_dims = output_grad->dims();
BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w, if (output_grad_dims.size() == 4) { // 2D interpolation grad
in_h, in_w, n, c, out_h, out_w, Interpolate2DCPUBwd<T>(ctx, input_grad, *output_grad);
align_corners, align_mode); } else if (output_grad_dims.size() == 5) { // 3D interpolation grad
} else if ("nearest" == interp_method) { Interpolate3DCPUBwd<T>(ctx, input_grad, *output_grad);
NearestNeighborInterpolateGrad<T>(*output_grad, input_grad, ratio_h,
ratio_w, n, c, out_h, out_w,
align_corners);
} }
} }
}; };
......
...@@ -119,6 +119,7 @@ __all__ = [ ...@@ -119,6 +119,7 @@ __all__ = [
'image_resize', 'image_resize',
'image_resize_short', 'image_resize_short',
'resize_bilinear', 'resize_bilinear',
'resize_trilinear',
'resize_nearest', 'resize_nearest',
'gather', 'gather',
'scatter', 'scatter',
...@@ -7672,13 +7673,16 @@ def image_resize(input, ...@@ -7672,13 +7673,16 @@ def image_resize(input,
""" """
**Resize a Batch of Images** **Resize a Batch of Images**
The input must be a tensor of the shape (num_batches, channels, in_h, in_w), The input must be a tensor of the shape (num_batches, channels, in_h, in_w)
and the resizing only applies on the last two dimensions(hight and width). or (num_batches, channels, in_d, in_h, in_w), and the resizing only applies
on the last two/three dimensions(depth, hight and width).
Supporting resample methods: Supporting resample methods:
'BILINEAR' : Bilinear interpolation 'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation 'NEAREST' : Nearest neighbor interpolation
Nearest neighbor interpolation is to perform nearest neighbor interpolation Nearest neighbor interpolation is to perform nearest neighbor interpolation
...@@ -7691,6 +7695,11 @@ def image_resize(input, ...@@ -7691,6 +7695,11 @@ def image_resize(input,
to perform linear interpolation first in one direction, and then to perform linear interpolation first in one direction, and then
again in the other direction. again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optinal parameters,the calculation method Align_corners and align_mode are optinal parameters,the calculation method
of interpolation can be selected by them. of interpolation can be selected by them.
...@@ -7748,30 +7757,58 @@ def image_resize(input, ...@@ -7748,30 +7757,58 @@ def image_resize(input,
H_out = H_{in} * scale_{factor} H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor} W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia: For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia: For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation. https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
Args: Args:
input (Variable): The input tensor of image resize layer, input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w). (num_batches, channels, in_h, in_w) or a
5-D tensor of the shape
(num_batches, channls, in_d, in_h, in_w).
out_shape(list|tuple|Variable|None): Output shape of image resize out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w) when
Default: None input is a 4-D tensor and is
(out_d, out_h, out_w) when input is a
5-D tensor. Default: None
scale(float|None): The multiplier for the input height or width. At scale(float|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set. least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`. And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None. Default: None.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' resample(str): The resample method. It supports 'BILINEAR', 'TRILINEAR'
currently. and 'NEAREST' currently. Default: 'BILINEAR'
Default: 'BILINEAR'
actual_shape(Variable): An optional input to specify output shape actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize dynamically. If provided, image resize
according to this given shape rather than according to this given shape rather than
...@@ -7795,15 +7832,19 @@ def image_resize(input, ...@@ -7795,15 +7832,19 @@ def image_resize(input,
Returns: Returns:
Variable: The output is a 4-D tensor of the shape Variable: The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w). (num_batches, channls, out_h, out_w) or a 5-D tensor of the shape
(num_batches, channels, out_d, out_h, out_w).
Raises: Raises:
TypeError: out_shape should be a list or tuple or Variable. TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None. TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR' ValueError: The 'resample' of image_resize can only be 'BILINEAR',
or 'NEAREST' currently. 'TRILINEAR' or 'NEAREST' currently.
ValueError: 'BILINEAR' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None. ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2. ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero. ValueError: scale should be greater than zero.
TypeError: align_corners shoule be a bool value TypeError: align_corners shoule be a bool value
ValueError: align_mode can only be '0' or '1' ValueError: align_mode can only be '0' or '1'
...@@ -7817,14 +7858,20 @@ def image_resize(input, ...@@ -7817,14 +7858,20 @@ def image_resize(input,
""" """
resample_methods = { resample_methods = {
'BILINEAR': 'bilinear', 'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear',
'NEAREST': 'nearest', 'NEAREST': 'nearest',
} }
if resample not in resample_methods: if resample not in resample_methods:
raise ValueError( raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." "The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' "
) "or 'NEAREST' currently.")
resample_type = resample_methods[resample] resample_type = resample_methods[resample]
if resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4:
raise ValueError("'BILINEAR' and 'NEAREST' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(input.shape) != 5:
raise ValueError("'TRILINEAR'only support 5-D tensor.")
if not isinstance(align_corners, bool): if not isinstance(align_corners, bool):
raise TypeError("Attr align_corners should be a bool value") raise TypeError("Attr align_corners should be a bool value")
if align_mode != 0 and align_mode != 1: if align_mode != 0 and align_mode != 1:
...@@ -7840,6 +7887,7 @@ def image_resize(input, ...@@ -7840,6 +7887,7 @@ def image_resize(input,
inputs = {"X": input} inputs = {"X": input}
attrs = { attrs = {
"out_d": 0,
"out_h": 0, "out_h": 0,
"out_w": 0, "out_w": 0,
"interp_method": resample_type, "interp_method": resample_type,
...@@ -7857,12 +7905,21 @@ def image_resize(input, ...@@ -7857,12 +7905,21 @@ def image_resize(input,
if not (_is_list_or_turple_(out_shape)): if not (_is_list_or_turple_(out_shape)):
raise TypeError( raise TypeError(
"out_shape should be a list or tuple or Variable.") "out_shape should be a list or tuple or Variable.")
if len(out_shape) != 2: if len(input.shape) == 4:
raise ValueError("out_shape length should be 2.") if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
out_shape = list(map(int, out_shape)) "input 4-D tensor.")
attrs['out_h'] = out_shape[0] out_shape = list(map(int, out_shape))
attrs['out_w'] = out_shape[1] attrs['out_h'] = out_shape[0]
attrs['out_w'] = out_shape[1]
if len(input.shape) == 5:
if len(out_shape) != 3:
raise ValueError("out_shape length should be 3 for "
"input 5-D tensor.")
out_shape = list(map(int, out_shape))
attrs['out_d'] = out_shape[0]
attrs['out_h'] = out_shape[1]
attrs['out_w'] = out_shape[2]
else: else:
if scale <= 0: if scale <= 0:
...@@ -7945,7 +8002,7 @@ def resize_bilinear(input, ...@@ -7945,7 +8002,7 @@ def resize_bilinear(input,
Args: Args:
input(${x_type}): ${x_comment}. input(${x_type}): input should be a 4-D tensor.
out_shape(list|tuple|Variable|None): Output shape of resize bilinear out_shape(list|tuple|Variable|None): Output shape of resize bilinear
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w).
...@@ -7974,7 +8031,7 @@ def resize_bilinear(input, ...@@ -7974,7 +8031,7 @@ def resize_bilinear(input,
align_mode(bool): ${align_mode_comment} align_mode(bool): ${align_mode_comment}
Returns: Returns:
${out_comment}. A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -7988,6 +8045,112 @@ def resize_bilinear(input, ...@@ -7988,6 +8045,112 @@ def resize_bilinear(input,
align_corners, align_mode) align_corners, align_mode)
@templatedoc(op_type="trilinear_interp")
def resize_trilinear(input,
out_shape=None,
scale=None,
name=None,
actual_shape=None,
align_corners=True,
align_mode=1):
"""
Resize input by performing trilinear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
in priority order.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
Align_corners and align_mode are optinal parameters,the calculation
method of interpolation can be selected by them.
Example:
.. code-block:: text
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Args:
input(${x_type}): input should be a 4-D tensor.
out_shape(list|tuple|Variable|None): Output shape of resize bilinear
layer, the shape is (out_d, out_h, out_w).
Default: None
scale(float|None): The multiplier for the input depth, height or width.
At least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
constructing stage.
Default: None
align_corners(bool): ${align_corners_comment}
align_mode(bool): ${align_mode_comment}
Returns:
A 5-D tensor in shape (num_batches, channels, out_d, out_h, out_w)
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input", shape=[3,6,9,11], dtype="float32")
out = fluid.layers.resize_trilinear(input, out_shape=[12, 12, 12])
"""
return image_resize(input, out_shape, scale, name, 'TRILINEAR',
actual_shape, align_corners, align_mode)
@templatedoc(op_type="nearest_interp") @templatedoc(op_type="nearest_interp")
def resize_nearest(input, def resize_nearest(input,
out_shape=None, out_shape=None,
...@@ -8041,7 +8204,7 @@ def resize_nearest(input, ...@@ -8041,7 +8204,7 @@ def resize_nearest(input,
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
Args: Args:
input(${x_type}): ${x_comment}. input(${x_type}): input should be a 4-D tensor.
out_shape(list|tuple|Variable|None): Output shape of resize nearest out_shape(list|tuple|Variable|None): Output shape of resize nearest
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w).
...@@ -8069,7 +8232,7 @@ def resize_nearest(input, ...@@ -8069,7 +8232,7 @@ def resize_nearest(input,
align_corners(bool): ${align_corners_comment} align_corners(bool): ${align_corners_comment}
Returns: Returns:
${out_comment}. A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
Examples: Examples:
.. code-block:: python .. code-block:: python
......
...@@ -205,6 +205,17 @@ class TestBilinearInterpCase6(TestBilinearInterpOp): ...@@ -205,6 +205,17 @@ class TestBilinearInterpCase6(TestBilinearInterpOp):
self.align_mode = 1 self.align_mode = 1
class TestBilinearInterpSame(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 128, 64]
self.out_h = 128
self.out_w = 64
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpActualShape(TestBilinearInterpOp): class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
......
...@@ -1295,16 +1295,74 @@ class TestBook(LayerTest): ...@@ -1295,16 +1295,74 @@ class TestBook(LayerTest):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_bilinear(x, out_shape=[12, 12]) output = layers.resize_bilinear(x, out_shape=[12, 12])
return (output) return (output)
output = layers.resize_bilinear(x, scale=3)
def make_resize_bilinear_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_bilinear(x, scale=1.5)
return (output) return (output)
def make_resize_nearest(self): def make_resize_nearest(self):
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12])
except ValueError:
pass
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(
name='x2', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12, 12])
except ValueError:
pass
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
fluid.default_startup_program()): fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12]) output = layers.resize_nearest(x, out_shape=[12, 12])
return (output) return (output)
output = layers.resize_nearest(x, scale=3)
def make_resize_nearest_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, scale=1.8)
return (output)
def make_resize_trilinear(self):
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x2', shape=[3, 9, 6], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12, 12])
except ValueError:
pass
try:
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(
name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12])
except ValueError:
pass
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, out_shape=[12, 12, 12])
return (output)
def make_resize_trilinear_by_scale(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
output = layers.resize_trilinear(x, scale=2.1)
return (output) return (output)
def make_polygon_box_transform(self): def make_polygon_box_transform(self):
......
...@@ -176,6 +176,16 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp): ...@@ -176,6 +176,16 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp):
self.align_corners = True self.align_corners = True
class TestNearestNeighborInterpSame(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 128
self.out_w = 64
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpActualShape(TestNearestInterpOp): class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def trilinear_interp_np(input,
out_d,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0):
"""trilinear interpolation implement in shape [N, C, D, H, W]"""
if out_size is not None:
out_d = out_size[0]
out_h = out_size[1]
out_w = out_size[2]
if actual_shape is not None:
out_d = actual_shape[0]
out_h = actual_shape[1]
out_w = actual_shape[2]
batch_size, channel, in_d, in_h, in_w = input.shape
ratio_d = ratio_h = ratio_w = 0.0
if out_d > 1:
if (align_corners):
ratio_d = (in_d - 1.0) / (out_d - 1.0)
else:
ratio_d = 1.0 * in_d / out_d
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_d, out_h, out_w))
for i in range(out_d):
if (align_mode == 0 and not align_corners):
d = int(ratio_d * (i + 0.5) - 0.5)
else:
d = int(ratio_d * i)
d = max(0, d)
did = 1 if d < in_d - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_d = max(ratio_d * (i + 0.5) - 0.5, 0)
d1lambda = idx_src_d - d
else:
d1lambda = ratio_d * i - d
d2lambda = 1.0 - d1lambda
for j in range(out_h):
if (align_mode == 0 and not align_corners):
h = int(ratio_h * (j + 0.5) - 0.5)
else:
h = int(ratio_h * j)
h = max(0, h)
hid = 1 if h < in_h - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_h = max(ratio_h * (j + 0.5) - 0.5, 0)
h1lambda = idx_src_h - h
else:
h1lambda = ratio_h * j - h
h2lambda = 1.0 - h1lambda
for k in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (k + 0.5) - 0.5)
else:
w = int(ratio_w * k)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (k + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * k - w
w2lambda = 1.0 - w1lambda
out[:, :, i, j, k] = \
d2lambda * \
(h2lambda * (w2lambda * input[:, :, d, h, w] + \
w1lambda * input[:, :, d, h, w+wid]) + \
h1lambda * (w2lambda * input[:, :, d, h+hid, w] + \
w1lambda * input[:, :, d, h+hid, w+wid])) + \
d1lambda * \
(h2lambda * (w2lambda * input[:, :, d+did, h, w] + \
w1lambda * input[:, :, d+did, h, w+wid]) + \
h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] + \
w1lambda * input[:, :, d+did, h+hid, w+wid]))
return out.astype(input.dtype)
class TestTrilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "trilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32")
if self.scale > 0:
out_d = int(self.input_shape[2] * self.scale)
out_h = int(self.input_shape[3] * self.scale)
out_w = int(self.input_shape[4] * self.scale)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_d': self.out_d,
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 4, 4, 4]
self.out_d = 2
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase1(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 1, 7, 8, 9]
self.out_d = 1
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase2(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 9, 6, 8]
self.out_d = 12
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase3(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 16, 8, 4]
self.out_d = 32
self.out_h = 16
self.out_w = 8
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase4(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [4, 1, 7, 8, 9]
self.out_d = 1
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2, 2]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase5(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 3, 9, 6, 8]
self.out_d = 12
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11, 11]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase6(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 8
self.out_h = 32
self.out_w = 16
self.scale = 0.
self.out_size = np.array([17, 9, 5]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpSame(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 16
self.out_h = 8
self.out_w = 4
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpSameHW(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 8
self.out_h = 8
self.out_w = 4
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpActualShape(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 16, 8, 4]
self.out_d = 64
self.out_h = 32
self.out_w = 16
self.scale = 0.
self.out_size = np.array([33, 19, 7]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "trilinear_interp"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
if self.scale > 0:
out_d = int(self.input_shape[2] * self.scale)
out_h = int(self.input_shape[3] * self.scale)
out_w = int(self.input_shape[4] * self.scale)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_d': self.out_d,
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 3, 9, 6, 8]
self.out_d = 13
self.out_h = 10
self.out_w = 9
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase1Uint8(TestTrilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 16, 8, 4]
self.out_d = 13
self.out_h = 7
self.out_w = 2
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase2Uint8(TestTrilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [4, 1, 7, 8, 9]
self.out_d = 3
self.out_h = 5
self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15, 21]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpOtherMethod1(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 1
class TestTrilinearInterpWithMethod2(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 0
class TestTrilinearInterpWithMethod3(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = True
self.align_mode = 0
class TestTrilinearInterpScale1(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 82
self.out_h = 60
self.out_w = 25
self.scale = 2.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpScale2(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 82
self.out_h = 60
self.out_w = 25
self.scale = 1.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpScale3(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 82
self.out_h = 60
self.out_w = 25
self.scale = 1.5
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpZero(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 11]
self.out_d = 82
self.out_h = 60
self.out_w = 25
self.scale = 0.2
self.align_corners = False
self.align_mode = 0
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册