Unverified commit d65a7a46 authored by chentianyu03, committed by GitHub

[Phi] Interpolate kernels into phi (#40855)

* add interpolate cpu kernel

* fix nullptr bug

* add interpolate gpu kernel

* fix unit test error

* remove raw kernels

* add cuda kernel impl

* add infermeta

* recover accidentally deleted kernels in interpolate op

* fix grad x_grad name error

* remove interpolate_v2_op.h

* rm unused codes

* fix xpu build error

* fix build error

* fix namespace error

* add register header for npu

* fix infermeta error

* modify by review

* add the missing args in test_trt_convert_nearest_interp_v2
Parent 597d7efd
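
In short, this change deletes the raw REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations for the *_interp_v2 operators and re-homes the computation in phi kernels, with shape inference routed through phi::InterpolateInferMeta. As a rough, illustrative sketch only (the kernel name and the reduced argument list below are assumptions, not copied from this commit), a phi-side CPU registration for one of these ops follows this pattern:

// Sketch under assumptions: the real phi interpolate kernel takes many more
// arguments (OutSize, SizeTensor, Scale, data_layout, align flags, ...).
namespace phi {
template <typename T, typename Context>
void BilinearInterpKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          DenseTensor* output);
}  // namespace phi

// PD_REGISTER_KERNEL binds the op type to the templated kernel function and
// the data types it supports, replacing the removed REGISTER_OP_CPU_KERNEL.
PD_REGISTER_KERNEL(bilinear_interp_v2,
                   CPU,
                   ALL_LAYOUT,
                   phi::BilinearInterpKernel,
                   float,
                   double,
                   uint8_t) {}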
@@ -2167,7 +2167,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
                  typeid(paddle::optional<const phi::DenseTensor&>)) ||
              input_defs[i].type_index ==
                  std::type_index(
-                     typeid(paddle::optional<const phi::SelectedRows&>)))) {
+                     typeid(paddle::optional<const phi::SelectedRows&>)) ||
+             input_defs[i].type_index ==
+                 std::type_index(
+                     typeid(paddle::optional<
+                            const std::vector<const phi::DenseTensor*>>)))) {
       pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
       auto end_idx = start_idx + 1;
       pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
@@ -2429,6 +2433,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
                    std::type_index(typeid(std::vector<std::string>))) {
        pt_kernel_context->EmplaceBackAttr(
            BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::vector<float>))) {
+        pt_kernel_context->EmplaceBackAttr(
+            BOOST_GET_CONST(std::vector<float>, attr_it->second));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "Unsupported cast op attribute `%s` when construct "
......
@@ -272,6 +272,14 @@ void BuildDygraphPhiKernelContext(
         auto end_idx = start_idx + 1;
         kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
         continue;
+      } else if (input_defs[i].type_index ==
+                 std::type_index(
+                     typeid(paddle::optional<
+                            const std::vector<const phi::DenseTensor*>>))) {
+        kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
+        auto end_idx = start_idx + 1;
+        kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
+        continue;
       } else {
         PADDLE_THROW(phi::errors::NotFound(
             "Can not find input variable '%s' for %s OP, please check whether "
@@ -545,6 +553,9 @@ void BuildDygraphPhiKernelContext(
                  std::type_index(typeid(std::vector<std::string>))) {
       kernel_ctx->EmplaceBackAttr(
           BOOST_GET_CONST(std::vector<std::string>, attr));
+    } else if (attr_defs[i].type_index ==
+               std::type_index(typeid(std::vector<float>))) {
+      kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported cast op attribute `%s` when construct "
......
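
Both framework hunks above serve the new phi interpolate kernels: interpolate_v2 exposes an optional tensor-list input (SizeTensor) and a std::vector<float> attribute (scale), so the static-graph and dygraph kernel-context builders must emplace a nullptr when an optional vector-of-DenseTensor input is not fed, and must know how to cast a std::vector<float> attribute. A minimal sketch of a kernel parameter list that exercises both new branches (the function name and reduced argument set are assumed for illustration, not the exact signature added by this PR):

// Hypothetical signature fragment; the real phi interpolate kernels declare
// more parameters, but these two are what the new branches above dispatch to.
template <typename T, typename Context>
void InterpKernelSketch(
    const Context& dev_ctx,
    const phi::DenseTensor& x,
    paddle::optional<const std::vector<const phi::DenseTensor*>> size_tensor,
    const std::vector<float>& scale,  // handled by the vector<float> attr cast
    phi::DenseTensor* output);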
@@ -9,11 +9,15 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/interpolate_v2_op.h"
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -722,64 +726,51 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateV2GradNoNeedBufferVarsInferer,
 // not
 // compatible with interp_op, so a new one is added in paddle2.0
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(bilinear_interp_v2, BilinearInterpInferShapeFunctor,
+                            PD_INFER_META(phi::InterpolateInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(nearest_interp_v2, NearestInterpInferShapeFunctor,
+                            PD_INFER_META(phi::InterpolateInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(trilinear_interp_v2,
+                            TrilinearInterpInferShapeFunctor,
+                            PD_INFER_META(phi::InterpolateInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(bicubic_interp_v2, BicubicInterpInferShapeFunctor,
+                            PD_INFER_META(phi::InterpolateInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(linear_interp_v2, LinearInterpInferShapeFunctor,
+                            PD_INFER_META(phi::InterpolateInferMeta));
 REGISTER_OPERATOR(bilinear_interp_v2, ops::InterpolateV2Op,
                   ops::InterpolateV2OpMaker,
                   ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
-                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
+                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
+                  BilinearInterpInferShapeFunctor);
 REGISTER_OPERATOR(bilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
                   ops::InterpolateV2GradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(nearest_interp_v2, ops::InterpolateV2Op,
                   ops::InterpolateV2OpMaker,
                   ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
-                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
+                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
+                  NearestInterpInferShapeFunctor);
 REGISTER_OPERATOR(nearest_interp_v2_grad, ops::InterpolateV2OpGrad,
                   ops::InterpolateV2GradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(trilinear_interp_v2, ops::InterpolateV2Op,
                   ops::InterpolateV2OpMaker,
                   ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
-                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
+                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
+                  TrilinearInterpInferShapeFunctor);
 REGISTER_OPERATOR(trilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
                   ops::InterpolateV2GradNoNeedBufferVarsInferer);
 REGISTER_OPERATOR(bicubic_interp_v2, ops::InterpolateV2Op,
                   ops::InterpolateV2OpMaker,
                   ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
-                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
+                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
+                  BicubicInterpInferShapeFunctor);
 REGISTER_OPERATOR(bicubic_interp_v2_grad, ops::InterpolateV2OpGrad,
                   ops::InterpolateV2GradNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(bilinear_interp_v2, ops::InterpolateV2Kernel<float>,
-                       ops::InterpolateV2Kernel<double>,
-                       ops::InterpolateV2Kernel<uint8_t>);
-REGISTER_OP_CPU_KERNEL(bilinear_interp_v2_grad,
-                       ops::InterpolateV2GradKernel<float>,
-                       ops::InterpolateV2GradKernel<double>);
-REGISTER_OP_CPU_KERNEL(nearest_interp_v2, ops::InterpolateV2Kernel<float>,
-                       ops::InterpolateV2Kernel<double>,
-                       ops::InterpolateV2Kernel<int>,
-                       ops::InterpolateV2Kernel<int64_t>,
-                       ops::InterpolateV2Kernel<uint8_t>);
-REGISTER_OP_CPU_KERNEL(nearest_interp_v2_grad,
-                       ops::InterpolateV2GradKernel<float>,
-                       ops::InterpolateV2GradKernel<double>);
-REGISTER_OP_CPU_KERNEL(trilinear_interp_v2, ops::InterpolateV2Kernel<float>,
-                       ops::InterpolateV2Kernel<double>,
-                       ops::InterpolateV2Kernel<uint8_t>);
-REGISTER_OP_CPU_KERNEL(trilinear_interp_v2_grad,
-                       ops::InterpolateV2GradKernel<float>,
-                       ops::InterpolateV2GradKernel<double>);
 REGISTER_OPERATOR(linear_interp_v2, ops::InterpolateV2Op,
                   ops::InterpolateV2OpMaker,
                   ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
-                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
+                  ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
+                  LinearInterpInferShapeFunctor);
 REGISTER_OPERATOR(linear_interp_v2_grad, ops::InterpolateV2OpGrad,
                   ops::InterpolateV2GradNoNeedBufferVarsInferer);
-REGISTER_OP_CPU_KERNEL(linear_interp_v2, ops::InterpolateV2Kernel<float>,
-                       ops::InterpolateV2Kernel<double>,
-                       ops::InterpolateV2Kernel<uint8_t>);
-REGISTER_OP_CPU_KERNEL(linear_interp_v2_grad,
-                       ops::InterpolateV2GradKernel<float>,
-                       ops::InterpolateV2GradKernel<double>);
-REGISTER_OP_CPU_KERNEL(bicubic_interp_v2, ops::InterpolateV2Kernel<float>,
-                       ops::InterpolateV2Kernel<double>);
-REGISTER_OP_CPU_KERNEL(bicubic_interp_v2_grad,
-                       ops::InterpolateV2GradKernel<float>,
-                       ops::InterpolateV2GradKernel<double>);
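
The registration changes above follow one pattern: each *_interp_v2 operator now registers an infer-shape functor that forwards to phi::InterpolateInferMeta, so shape inference is shared with the phi kernels instead of living only in the fluid op. In isolation, with a hypothetical op name, the pattern is:

// Generic form of the functor-based shape inference used above; the op name
// and functor name here are placeholders, not ops added by this commit.
DECLARE_INFER_SHAPE_FUNCTOR(my_interp_like_op,              // op type string
                            MyInterpLikeInferShapeFunctor,  // generated functor
                            PD_INFER_META(phi::InterpolateInferMeta));
REGISTER_OPERATOR(my_interp_like_op, ops::InterpolateV2Op,
                  ops::InterpolateV2OpMaker,
                  MyInterpLikeInferShapeFunctor);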
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::FastDivMod;
using DataLayout = framework::DataLayout;
static inline int GetLastPow2(int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return std::max(1, n - (n >> 1));
}
inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
const platform::CUDADeviceContext& context, int num_img, int height,
int width) {
const int kThreadsPerBlock = 256;
int max_threads_per_block = context.GetMaxThreadsPerBlock(); // 1024
int max_threads = std::min(kThreadsPerBlock, max_threads_per_block);
int block_x = std::min(GetLastPow2(width), max_threads);
int block_y = std::min(GetLastPow2(height), max_threads / block_x);
int block_z = std::min(num_img, max_threads / block_x / block_y);
auto max_grid_dim = context.GetCUDAMaxGridDimSize();
int grid_x = std::min<int>(max_grid_dim[0], platform::DivUp(width, block_x));
int grid_y = std::min<int>(max_grid_dim[1], platform::DivUp(height, block_y));
int grid_z =
std::min<int>(max_grid_dim[2], platform::DivUp(num_img, block_z * 4));
const int capability = context.GetComputeCapability();
platform::GpuLaunchConfig config;
config.compute_capability = capability;
config.thread_per_block = dim3(block_x, block_y, block_z);
config.block_per_grid = dim3(grid_x, grid_y, grid_z);
return config;
}
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
int* in_img_idx, int* x_id, T* lambda1, T* lambda2, T src_x,
const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f;
*in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx;
*lambda2 = 1.f - *lambda1;
}
struct FastDivModForInterpolate {
public:
FastDivMod channels_div;
FastDivMod output_w_div;
FastDivMod output_wc_div;
explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
const int output_w,
const int outout_wc)
: channels_div(FastDivMod(channels)),
output_w_div(FastDivMod(output_w)),
output_wc_div(FastDivMod(outout_wc)) {}
};
template <typename T>
__global__ void KeNearestNeighborInterpNCHWFw(
const T* in, const size_t in_img_h, const size_t in_img_w, T* out,
const size_t out_img_h, const size_t out_img_w, const size_t nc,
const float ratio_h, const float ratio_w, const bool align_corners) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
// nearest_sampling by multiple read in_addr and write to out_addr
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
// prevent from multiple threads writing
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
out[out_index] = in[in_index];
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
template <typename T>
__global__ void KeNearestNeighbor3DInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w; // ncdhw
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
out[tid] = in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpNCHWBw(
T* in, const size_t in_img_h, const size_t in_img_w, const T* out,
const size_t out_img_h, const size_t out_img_w, const size_t nc,
const float ratio_h, const float ratio_w, const bool align_corners) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
// nearest_sampling by multiple read in_addr and write to out_addr
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
// prevent from multiple threads writing
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
T* in_pos = &in[in_index];
const T out_pos = out[out_index];
platform::CudaAtomicAdd(in_pos, out_pos);
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
const T out_pos = out[tid];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T>
__global__ void KeNearestNeighbor3DInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
in_pos = &in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T>
__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w,
const size_t input_w, T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
const T* in_pos =
&in[out_id_h * out_id_w + channel_id * in_img_size + in_img_idx];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id];
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels];
}
}
}
template <typename T>
__global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
const size_t input_w, const T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const T ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
}
const T* out_pos = &out[out_id_w];
if (data_layout == DataLayout::kNCHW) {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
} else {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
w1lambda * out_pos[0]);
}
}
}
template <typename T>
__global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
const size_t in_img_w, T* out,
const size_t out_img_h,
const size_t out_img_w, const size_t nc,
const float ratio_h, const float ratio_w,
const T align_type_value) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_img_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_img_h);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
// prevent from multiple threads writing
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
const T* in_pos = &in[in_index];
out[out_index] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const T align_type_value, FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_img_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_img_h);
// bilinear interpolation
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
out[tid] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda *
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda *
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
}
}
/* Calculate the minimum of partial elements in a block */
template <typename T>
__inline__ __device__ T PartialBlockMin(T val, size_t threads_num_in_block,
unsigned mask) {
__shared__ T shared[WARP_SIZE];
__shared__ T shared_last_val;
__shared__ int shared_last_idx;
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
int threshold = (threads_num_in_block & (-WARP_SIZE));
if (threadIdx.x < threshold) {
shared_last_idx = (threshold >> 5) - 1;
val = phi::funcs::warpReduceMin(val, mask);
if (lane == 0) {
shared[wid] = val;
}
} else {
shared_last_val = std::numeric_limits<T>::max();
platform::CudaAtomicMin(&shared_last_val, val);
shared[wid] = shared_last_val;
shared_last_idx = wid;
}
__syncthreads();
if (threadIdx.x < threshold) {
val = (lane <= shared_last_idx) ? shared[lane]
: std::numeric_limits<T>::max();
val = phi::funcs::warpReduceMin(val, mask);
shared_last_val = val;
}
__syncthreads();
if (threadIdx.x >= threshold) {
val = shared_last_val;
}
return val;
}
template <typename T>
__global__ void KeBilinearInterpBwShareMemory(
T* in, const int in_h, const int in_w, const T* __restrict__ out,
const int out_h, const int out_w, const int n, const int num_channels,
float ratio_h, float ratio_w, const T align_type_value, bool is_nchw) {
__shared__ T s_data[2][1024];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int out_chw = num_channels * out_h * out_w;
int nthreads = n * out_chw;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
// top_left_index is just input_index.
int input_index = out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx;
int top_right_index = input_index + w_id;
int bot_left_index = input_index + h_id * in_w;
int bot_right_index = input_index + h_id * in_w + w_id;
int in_top_min_index, in_bot_min_index;
s_data[0][threadIdx.x] = 0.f;
s_data[1][threadIdx.x] = 0.f;
int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index =
phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
int in_bot_max_index =
phi::funcs::blockReduceMax(bot_right_index, FINAL_MASK);
if (remain > blockDim.x) {
in_top_min_index = phi::funcs::blockReduceMin(input_index, FINAL_MASK);
in_bot_min_index = phi::funcs::blockReduceMin(bot_left_index, FINAL_MASK);
} else {
in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
}
int upper_limit_share_idx = (in_top_max_index - in_top_min_index) >
(in_bot_max_index - in_bot_min_index)
? (in_top_max_index - in_top_min_index)
: (in_bot_max_index - in_bot_min_index);
if (h_id != 0) {
platform::CudaAtomicAdd(&s_data[0][input_index - in_top_min_index],
h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_right_index - in_bot_min_index],
h1lambda * w1lambda * value);
} else {
platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
(h2lambda + h1lambda) * w1lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
(h1lambda + h2lambda) * w2lambda * value);
}
__syncthreads();
if (threadIdx.x <= upper_limit_share_idx) {
platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
}
}
}
__device__ __forceinline__ int GetInputIndex(const size_t nc, const int height,
const int width, const int h,
const int w) {
return (nc * height + h) * width + w;
}
template <typename T>
__global__ void KeBilinearInterpNCHWBw(T* in, const int in_h, const int in_w,
const int out_h, const int out_w,
const int n, const int num_channels,
float ratio_h, float ratio_w,
const T* __restrict__ out,
const T align_type_value) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x;
int num_out = n * num_channels * out_h * out_w;
int num_in = n * num_channels * in_h * in_w;
for (; index < num_out; index += stride) {
int index_tmp = index;
int w2 = index_tmp % out_w;
index_tmp /= out_w;
int h2 = index_tmp % out_h;
int nc = index_tmp / out_h;
int h1, y_id;
T h1lambda, h0lambda;
T src_y = ratio_h * (h2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&h1, &y_id, &h1lambda, &h0lambda,
src_y, in_h);
int w1, x_id;
T w1lambda, w0lambda;
T src_x = ratio_w * (w2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&w1, &x_id, &w1lambda, &w0lambda,
src_x, in_w);
T d2val = out[index];
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
}
}
template <typename T>
__global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
const T* __restrict__ out, const int out_h,
const int out_w, const int n,
const int out_chw, const int num_channels,
float ratio_h, float ratio_w,
const T align_type_value,
FastDivModForInterpolate divmods) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int nthreads = n * out_chw;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T value = out[tid];
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
}
template <typename T>
__global__ void KeTrilinearInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
w1lambda * in_pos1[h_id * in_img_w + w_id])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
w1lambda * in_pos2[h_id * in_img_w + w_id]));
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] +
w1lambda * in_pos1[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w * num_channels] +
w1lambda * in_pos1[h_id * in_img_w * num_channels +
w_id * num_channels])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] +
w1lambda * in_pos2[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w * num_channels] +
w1lambda * in_pos2[h_id * in_img_w * num_channels +
w_id * num_channels]));
}
}
}
template <typename T>
__global__ void KeTrilinearInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners,
const int align_mode, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
platform::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[w_id],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
platform::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[w_id * num_channels],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w * num_channels],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels + w_id * num_channels],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id * num_channels],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w * num_channels],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels + w_id * num_channels],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
}
}
template <typename T>
__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1,
const T x2, const T x3,
T t) {
T coeffs[4];
T a = -0.75;
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = cubic_convolution2<T>(x_1 + 1.0, a);
coeffs[1] = cubic_convolution1<T>(x_1, a);
coeffs[2] = cubic_convolution1<T>(x_2, a);
coeffs[3] = cubic_convolution2<T>(x_2 + 1.0, a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
__global__ void KeBicubicInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T coefficients[4];
const T* in_pos_0;
const T* in_pos_1;
const T* in_pos_2;
const T* in_pos_3;
int access_x_0;
if (data_layout == DataLayout::kNCHW) {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_0];
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_1];
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_2];
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_3];
coefficients[k] = Kecubic_interp<T>(in_pos_0[0], in_pos_1[0],
in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
Kecubic_interp<T>(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t);
} else {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
int access_x_0 =
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
const T* in_pos_0 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_0 * num_channels + channel_id];
const T* in_pos_1 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_1 * num_channels + channel_id];
const T* in_pos_2 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_2 * num_channels + channel_id];
const T* in_pos_3 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0],
in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t));
}
}
}
template <typename T>
__global__ void KeBicubicInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients(x_coeffs, x_t);
get_cubic_upsample_coefficients(y_coeffs, y_t);
const T* out_pos = &out[out_id_h * output_w + out_id_w];
T* in_pos;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
int access_y = max(min(static_cast<int>(input_y - 1 + j),
static_cast<int>(in_img_h - 1)),
0);
int access_x = max(min(static_cast<int>(input_x - 1 + i),
static_cast<int>(in_img_w - 1)),
0);
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x];
} else {
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
platform::CudaAtomicAdd(&in_pos[0],
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
}
template <typename T>
static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_w = new_size[0];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
}
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1.0) / (out_w - 1.0)
: static_cast<float>(new_scale_w);
}
int64_t in_cw = c * in_w;
int64_t out_cw = c * out_w;
auto pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
float scale_h = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
}
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_hw = in_h * in_w;
int64_t out_hw = out_h * out_w;
int64_t in_chw = c * in_hw;
int64_t out_chw = c * out_hw;
auto pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
platform::GpuLaunchConfig config_3d =
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeNearestNeighborInterpNCHWFw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
ratio_w, align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeNearestNeighborInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, interp_divmods);
}
} else if ("bilinear" == interp_method) {
dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
if (config.compute_capability == 53 || config.compute_capability == 62) {
thread_num = 512;
}
#endif
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
platform::GpuLaunchConfig config_3d =
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeBilinearInterpNCHWFw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
ratio_w, align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpFw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1;
float scale_d = -1;
float scale_h = -1;
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
}
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
"out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w};
} else {
dim_out = {n, out_d, out_h, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_dhw = in_d * in_h * in_w;
int64_t out_dhw = out_d * out_h * out_w;
int64_t in_cdhw = c * in_dhw;
int64_t out_cdhw = c * out_dhw;
auto pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
}
}
template <typename T>
static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
  auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::funcs::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_cw = c * in_w;
int64_t out_cw = c * out_w;
auto pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
ratio_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
  auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::funcs::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_hw = in_h * in_w;
int64_t out_hw = out_h * out_w;
int64_t in_chw = c * in_hw;
int64_t out_chw = c * out_hw;
auto pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) {
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
platform::GpuLaunchConfig config_3d =
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeNearestNeighborInterpNCHWBw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, nc,
ratio_h, ratio_w, align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeNearestNeighborInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners,
interp_divmods);
}
} else if ("bilinear" == interp_method) {
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
    bool is_nchw = (data_layout == DataLayout::kNCHW);
bool optimize_flag = false;
#ifndef __HIPCC__
optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6))
? true
: ((in_h == 1 && in_w == 1) ? true : false);
#endif
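    // The shared-memory backward is only used for NCHW when the output is
    // much larger than the input (roughly 64x per dimension) or the input is
    // a single pixel; HIP builds always take the generic kernels below.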
if (optimize_flag & is_nchw) {
KeBilinearInterpBwShareMemory<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
ratio_h, ratio_w, align_type_value, is_nchw);
} else if (!optimize_flag & is_nchw) {
      // NCHW backward without the shared-memory optimization
const int num_kernels = n * c * out_h * out_w;
const int num_threads =
std::min(ctx.cuda_device_context().GetMaxThreadsPerBlock(), 1024);
KeBilinearInterpNCHWBw<
T><<<platform::DivUp(num_kernels, num_threads), num_threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, out_h, out_w, n, c, ratio_h, ratio_w,
output_grad_data, align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpBw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad,
const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::funcs::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_dhw = in_d * in_h * in_w;
int64_t out_dhw = out_d * out_h * out_w;
int64_t in_cdhw = c * in_dhw;
int64_t out_cdhw = c * out_dhw;
auto pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
}
}
template <typename T>
class InterpolateOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDAFwd<T>(ctx, *input, output);
}
}
};
template <typename T>
class InterpolateV2GradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(bilinear_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int64_t>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(linear_interp_v2, ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(linear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -401,7 +403,8 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> { ...@@ -401,7 +403,8 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w); phi::funcs::ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
...@@ -431,14 +434,15 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> { ...@@ -431,14 +434,15 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> {
out_w = output_w[0]; out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) { } else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size); auto out_size_data = phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} else { } else {
auto scale_tensor = ctx.Input<Tensor>("Scale"); auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) { if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_h = scale_data[0]; scale_h = scale_data[0];
scale_w = scale_data[1]; scale_w = scale_data[1];
...@@ -538,7 +542,8 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> { ...@@ -538,7 +542,8 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> {
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); phi::funcs::ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
...@@ -567,14 +572,15 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> { ...@@ -567,14 +572,15 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> {
out_w = output_w[0]; out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) { } else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size); auto out_size_data = phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} else { } else {
auto scale_tensor = ctx.Input<Tensor>("Scale"); auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) { if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_h = scale_data[0]; scale_h = scale_data[0];
scale_w = scale_data[1]; scale_w = scale_data[1];
......
...@@ -14,8 +14,7 @@ ...@@ -14,8 +14,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/interpolate_v2_op.h" #include "paddle/phi/kernels/funcs/interpolate_function.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
namespace paddle { namespace paddle {
...@@ -57,7 +56,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> { ...@@ -57,7 +56,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w); phi::funcs::ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
...@@ -78,7 +78,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> { ...@@ -78,7 +78,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
auto scale_tensor = ctx.Input<Tensor>("Scale"); auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) { if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_h = scale_data[0]; scale_h = scale_data[0];
scale_w = scale_data[1]; scale_w = scale_data[1];
...@@ -107,7 +108,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> { ...@@ -107,7 +108,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size); auto out_size_data =
phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} }
...@@ -169,7 +171,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> { ...@@ -169,7 +171,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); phi::funcs::ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
...@@ -190,7 +193,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> { ...@@ -190,7 +193,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
auto scale_tensor = ctx.Input<Tensor>("Scale"); auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) { if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_h = scale_data[0]; scale_h = scale_data[0];
scale_w = scale_data[1]; scale_w = scale_data[1];
...@@ -219,7 +223,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> { ...@@ -219,7 +223,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size); auto out_size_data =
phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} }
......
...@@ -179,6 +179,43 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, ...@@ -179,6 +179,43 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
return config; return config;
} }
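// Rounds n down to the largest power of two not exceeding it (e.g. 200 -> 128);
// used below to pick power-of-two block dimensions.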
static inline int GetLastPow2(int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return std::max(1, n - (n >> 1));
}
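// Builds a 3D launch config for the NCHW interpolate kernels: power-of-two
// block dims capped at 256 threads per block, grid dims clamped to the device
// limits, and a z grid of num_img / (block_z * 4) blocks (callers presumably
// grid-stride over the remaining images).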
inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context,
int num_img,
int height,
int width) {
const int kThreadsPerBlock = 256;
int max_threads_per_block = context.GetMaxThreadsPerBlock(); // 1024
int max_threads = std::min(kThreadsPerBlock, max_threads_per_block);
int block_x = std::min(GetLastPow2(width), max_threads);
int block_y = std::min(GetLastPow2(height), max_threads / block_x);
int block_z = std::min(num_img, max_threads / block_x / block_y);
auto max_grid_dim = context.GetCUDAMaxGridDimSize();
int grid_x =
std::min<int>(max_grid_dim[0], backends::gpu::DivUp(width, block_x));
int grid_y =
std::min<int>(max_grid_dim[1], backends::gpu::DivUp(height, block_y));
int grid_z = std::min<int>(max_grid_dim[2],
backends::gpu::DivUp(num_img, block_z * 4));
const int capability = context.GetComputeCapability();
GpuLaunchConfig config;
config.compute_capability = capability;
config.thread_per_block = dim3(block_x, block_y, block_z);
config.block_per_grid = dim3(grid_x, grid_y, grid_z);
return config;
}
} // namespace gpu } // namespace gpu
} // namespace backends } // namespace backends
} // namespace phi } // namespace phi
......
...@@ -87,6 +87,23 @@ std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start, ...@@ -87,6 +87,23 @@ std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
return result; return result;
} }
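// Returns the inputs in [start, end) as an optional vector; if the first slot
// is empty, the whole optional multi-input is treated as absent and
// paddle::none is returned.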
paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = inputs_.at(start);
if (first) {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
}
return paddle::optional<const std::vector<const MetaTensor*>>(result);
}
return paddle::optional<const std::vector<const MetaTensor*>>(paddle::none);
}
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get(); return outputs_.at(idx).get();
} }
......
...@@ -54,6 +54,8 @@ class InferMetaContext { ...@@ -54,6 +54,8 @@ class InferMetaContext {
const MetaTensor& InputAt(size_t idx) const; const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const; paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const; std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx); MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end); std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
...@@ -174,6 +176,26 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -174,6 +176,26 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
} }
}; };
template <typename... Tail>
struct InferMetaFnCallHelper<
paddle::optional<const std::vector<const MetaTensor*>>,
Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
paddle::optional<const std::vector<const MetaTensor*>> arg =
ctx->OptionalInputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
// TODO(chenweihang): support other attr type later // TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
......
...@@ -97,6 +97,22 @@ class KernelContext { ...@@ -97,6 +97,22 @@ class KernelContext {
return v; return v;
} }
template <typename TensorType>
paddle::optional<const std::vector<const TensorType*>> OptionalInputsBetween(
size_t start, size_t end) {
const auto& first = inputs_.at(start);
if (first) {
std::vector<const TensorType*> v;
for (size_t i = start; i < end; ++i) {
auto* t = static_cast<const TensorType*>(inputs_.at(i));
v.emplace_back(t);
}
return paddle::optional<const std::vector<const TensorType*>>(v);
}
return paddle::optional<const std::vector<const TensorType*>>(paddle::none);
}
template <typename TensorType> template <typename TensorType>
TensorType* MutableOutputAt(size_t idx) { TensorType* MutableOutputAt(size_t idx) {
return static_cast<TensorType*>(outputs_.at(idx)); return static_cast<TensorType*>(outputs_.at(idx));
......
...@@ -81,6 +81,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -81,6 +81,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<
const std::vector<const DenseTensor*>>))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid( } else if (arg_type == std::type_index(typeid(
paddle::optional<const SelectedRows&>))) { paddle::optional<const SelectedRows&>))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
......
...@@ -126,6 +126,30 @@ namespace phi { ...@@ -126,6 +126,30 @@ namespace phi {
} \ } \
} }
#define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper< \
paddle::optional<const std::vector<const tensor_type*>>, \
Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(attr_idx == 0, \
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
paddle::optional<const std::vector<const tensor_type*>> arg = \
ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg); \
} \
}
#define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \ #define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \
template <typename... Tail> \ template <typename... Tail> \
struct KernelCallHelper<attr_type, Tail...> { \ struct KernelCallHelper<attr_type, Tail...> { \
...@@ -224,6 +248,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -224,6 +248,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
......
...@@ -890,6 +890,506 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x, ...@@ -890,6 +890,506 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
static void Interpolate1DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE_EQ("linear",
interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"linear\" when"
"Input(X) dimension is 3, but got method = %s .",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
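  // Resolution order mirrors the kernels: SizeTensor, if given, wins over
  // OutSize and Scale; its values are only known at runtime, so the static
  // shape below falls back to the out_* attributes.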
if (size_tensor && size_tensor->size() > 0) {
    // top priority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
1,
phi::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 1. "
"Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
"size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0],
1,
phi::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .",
scale_tensor_dim[0]));
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_w = -1;
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
if (scale_w > 0.) {
// round down
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_w)
: static_cast<int>(dim_x[1] * scale_w));
// protect when input shape is -1
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimention = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0],
1,
phi::errors::InvalidArgument(
"OutSize's 0-th dimension's value must be 1, but got value = %d .",
out_size_dim[0]));
    // dims will be set in the kernel
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w_tmp};
} else {
dim_out = {dim_x[0], out_w_tmp, dim_x[2]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
static void Interpolate2DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4, but got method = %s.",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
if (size_tensor && size_tensor->size()) {
    // top priority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
2,
phi::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input "
"tensor, but got size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_h_tmp, out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 2 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor_dim[0]));
out_h_tmp = -1;
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_h = -1;
float scale_w = -1;
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
phi::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
if (scale_h > 0. && scale_w > 0.) {
// round down
out_h_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_h)
: static_cast<int>(dim_x[1] * scale_h));
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_w)
: static_cast<int>(dim_x[2] * scale_w));
// protect when input shape is -1
out_h_tmp = out_h_tmp > 0 ? out_h_tmp : -1;
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_h_tmp = out_h;
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimension = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0],
2,
phi::errors::InvalidArgument(
"OutSize's dim[0] must be 2, but got dimention = %d .",
out_size_dim[0]));
    // dims will be set in the kernel
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h_tmp, out_w_tmp};
} else {
dim_out = {dim_x[0], out_h_tmp, out_w_tmp, dim_x[3]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
static void Interpolate3DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" or "
"\"nearest\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
if (size_tensor && size_tensor->size() > 0) {
    // top priority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
3,
phi::errors::InvalidArgument(
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input "
"tensor, but got size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_d_tmp, out_h_tmp, out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got size = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 3 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 3 or 1, but got shape = %d .",
scale_tensor_dim[0]));
out_d_tmp = -1;
out_h_tmp = -1;
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
phi::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
phi::errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
// round down
out_d_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_d)
: static_cast<int>(dim_x[1] * scale_d));
out_h_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_h)
: static_cast<int>(dim_x[2] * scale_h));
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[4] * scale_w)
: static_cast<int>(dim_x[3] * scale_w));
// protect when input shape is -1
out_d_tmp = out_d_tmp > 0 ? out_d_tmp : -1;
out_h_tmp = out_h_tmp > 0 ? out_h_tmp : -1;
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_d_tmp = out_d;
out_h_tmp = out_h;
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got size is %d.",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(out_size_dim[0],
3,
phi::errors::InvalidArgument(
"OutSize's dim[0] must be 3, but got size is %d.",
out_size_dim[0]));
    // dims will be set in the kernel at runtime
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d_tmp, out_h_tmp, out_w_tmp};
} else {
dim_out = {dim_x[0], out_d_tmp, out_h_tmp, out_w_tmp, dim_x[4]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
void InterpolateInferMeta(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims(); // NCHW format
PADDLE_ENFORCE(
dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
phi::errors::Unimplemented(
"Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
dim_x.size()));
if (dim_x.size() == 3) {
// shape check for 1D interpolate for input tensor shape NCHW
Interpolate1DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
} else if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
} else { // dim_x.size() == 5
// shape check for 3D interpolate for input tensor shape NCDHW
Interpolate3DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
}
}
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
  auto inputs_dims = GetMetaTensorsDim(x);
...
...@@ -199,6 +199,22 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
                                  MetaTensor* pre_out,
                                  MetaTensor* w_out);
void InterpolateInferMeta(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config = MetaConfig());
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/interpolate_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
static void LinearInterpolationGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_w,
const int in_w,
const int n,
const int c,
const int out_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 3>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 3>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// linear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, l);
input_grad_t(i, j, x_w) += static_cast<T>(grad * d_e);
input_grad_t(i, j, x_e) += static_cast<T>(grad * d_w);
} else {
const T grad = output_grad_t(i, l, j);
input_grad_t(i, x_w, j) += static_cast<T>(grad * d_e);
input_grad_t(i, x_e, j) += static_cast<T>(grad * d_w);
}
}
}
}
}
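// Illustrative sketch (standalone, names made up): the index/weight math that
// the linear forward and backward kernels share. For one output column l,
// x_w/x_e are the two source columns and d_w/d_e the lambdas that sum to 1;
// the grad loop above scatters grad * d_e into x_w and grad * d_w into x_e.
#include <algorithm>

struct LinearWeights {
  int x_w, x_e;
  float d_w, d_e;  // d_w is the distance from x_w, i.e. the weight of x_e
};

static LinearWeights LinearSourceWeights(int l, float ratio_w, int in_w,
                                         bool align_flag) {
  float idx = align_flag ? ratio_w * (l + 0.5f) - 0.5f : ratio_w * l;
  idx = std::max(idx, 0.f);
  int x_w = static_cast<int>(idx);
  int x_e = std::min(x_w + 1, in_w - 1);
  float d_w = idx - x_w;
  return {x_w, x_e, d_w, 1.f - d_w};
}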
template <typename T>
static void BilinearInterpolationGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_h,
const float ratio_w,
const int in_h,
const int in_w,
const int n,
const int c,
const int out_h,
const int out_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
} else {
const T grad = output_grad_t(i, k, l, j);
input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
}
}
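// Illustrative note (standalone sketch): the four corner weights above are
// outer products of the per-axis lambdas, and (d_s + d_n) * (d_e + d_w) == 1,
// so the backward scatter distributes exactly one unit of gradient per output
// pixel across its four source corners.
#include <cassert>
#include <cmath>

static void BilinearCornerWeightsExample() {
  float d_n = 0.25f, d_s = 1.f - d_n;  // vertical lambdas
  float d_w = 0.40f, d_e = 1.f - d_w;  // horizontal lambdas
  float w_nw = d_s * d_e;              // weight scattered into (y_n, x_w)
  float w_ne = d_s * d_w;              // (y_n, x_e)
  float w_sw = d_n * d_e;              // (y_s, x_w)
  float w_se = d_n * d_w;              // (y_s, x_e)
  assert(std::fabs(w_nw + w_ne + w_sw + w_se - 1.f) < 1e-6f);
}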
template <typename T>
static void NearestNeighborInterpolateGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_h,
const float ratio_w,
const int n,
const int c,
const int out_h,
const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
} else {
input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
}
}
}
}
}
}
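// Illustrative sketch (standalone): the nearest-neighbor source index used
// above. With align_corners the index is rounded, otherwise truncated; each
// output gradient is then accumulated into that single source element.
static int NearestSourceIndex(int k, float ratio, bool align_corners) {
  return align_corners ? static_cast<int>(ratio * k + 0.5f)
                       : static_cast<int>(ratio * k);
}

// Example: upsampling H from 4 to 8 gives ratio_h = 0.5 without align_corners,
// so output rows {0..7} read from (and back-propagate to) input rows
// {0, 0, 1, 1, 2, 2, 3, 3}.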
template <typename T>
static void BicubicInterpolationGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_h,
const float ratio_w,
const int in_h,
const int in_w,
const int n,
const int c,
const int out_h,
const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
T x_t = x_n - input_x;
T x_coeffs[4];
T y_coeffs[4];
funcs::get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
funcs::get_cubic_upsample_coefficients<T>(y_coeffs, y_t);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bicubic interpolation grad
for (int ii = 0; ii < 4; ii++) {
for (int jj = 0; jj < 4; jj++) {
int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
static_cast<int>(0));
int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, access_y, access_x) +=
grad * y_coeffs[jj] * x_coeffs[ii];
} else {
T grad = output_grad_t(i, k, l, j);
input_grad_t(i, access_y, access_x, j) +=
grad * y_coeffs[jj] * x_coeffs[ii];
}
}
}
}
}
}
}
}
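// Illustrative sketch (standalone restatement of the coefficients that
// funcs::get_cubic_upsample_coefficients produces, i.e. the Keys cubic kernel
// with A = -0.75): the grad loop above reuses these four weights per axis, so
// each output gradient is spread over a 4 x 4 input patch with weight
// y_coeffs[jj] * x_coeffs[ii].
#include <cassert>
#include <cmath>

static void CubicCoefficients(float coeffs[4], float t) {
  const float A = -0.75f;
  auto near = [&](float x) { return ((A + 2) * x - (A + 3)) * x * x + 1; };
  auto far = [&](float x) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; };
  coeffs[0] = far(t + 1.f);   // neighbor at distance 1 + t
  coeffs[1] = near(t);        // distance t
  coeffs[2] = near(1.f - t);  // distance 1 - t
  coeffs[3] = far(2.f - t);   // distance 2 - t
}

static void CubicCoefficientsExample() {
  float c[4];
  CubicCoefficients(c, 0.3f);
  assert(std::fabs(c[0] + c[1] + c[2] + c[3] - 1.f) < 1e-5f);  // sums to 1
}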
template <typename T>
static void TrilinearInterpolationGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const int in_d,
const int in_h,
const int in_w,
const int n,
const int c,
const int out_d,
const int out_h,
const int out_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int j = 0; j < out_d; j++) { // loop for D
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
for (int k = 0; k < out_h; k++) { // loop for H
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) { // loop for W
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
// trilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(b, i, j, k, l);
input_grad_t(b, i, t_f, y_n, x_w) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, i, t_f, y_n, x_e) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, i, t_f, y_s, x_w) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, i, t_f, y_s, x_e) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, i, t_b, y_n, x_w) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, i, t_b, y_n, x_e) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, i, t_b, y_s, x_w) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, i, t_b, y_s, x_e) +=
static_cast<T>(grad * d_f * d_n * d_w);
} else {
const T grad = output_grad_t(b, j, k, l, i);
input_grad_t(b, t_f, y_n, x_w, i) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, t_f, y_n, x_e, i) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, t_f, y_s, x_w, i) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, t_f, y_s, x_e, i) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, t_b, y_n, x_w, i) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, t_b, y_n, x_e, i) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, t_b, y_s, x_w, i) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, t_b, y_s, x_e, i) +=
static_cast<T>(grad * d_f * d_n * d_w);
}
}
}
}
}
}
}
template <typename T>
static void NearestNeighbor3DInterpolateGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const int n,
const int c,
const int out_d,
const int out_h,
const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
for (int d = 0; d < out_d; d++) {
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_d, in_k, in_l) +=
output_grad_t(i, j, d, k, l);
} else {
input_grad_t(i, in_d, in_k, in_l, j) +=
output_grad_t(i, d, k, l, j);
}
}
}
}
}
}
}
template <typename T, typename Context>
static void Interpolate1DCPUBwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_w = -1.0;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
auto out_size_data =
funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_w = out_size_data[0];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_w = new_size[0];
}
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->Resize(dim_grad);
dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("linear" == interp_method) {
LinearInterpolationGrad<T>(output_grad,
input_grad,
ratio_w,
in_w,
n,
c,
out_w,
align_corners,
align_mode,
data_layout);
}
}
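// Illustrative sketch (standalone; std::optional stands in for the optional
// DenseTensor inputs): the precedence applied above when resolving out_w.
// A positive scale is applied first, an OutSize tensor overrides the scaled
// value, and a non-empty SizeTensor list overrides both.
#include <optional>

static int ResolveOutW(int in_w, float scale_w, std::optional<int> out_size_w,
                       std::optional<int> size_tensor_w, int attr_out_w) {
  int out_w = attr_out_w;
  if (scale_w > 0.f) out_w = static_cast<int>(in_w * scale_w);
  if (out_size_w) out_w = *out_size_w;        // OutSize wins over scale
  if (size_tensor_w) out_w = *size_tensor_w;  // SizeTensor wins over both
  return out_w;
}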
template <typename T, typename Context>
static void Interpolate2DCPUBwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_h = -1;
float scale_w = -1;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_w = scale_data[0];
scale_h = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
auto out_size_data =
funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_h = out_size_data[0];
out_w = out_size_data[1];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_h = new_size[0];
out_w = new_size[1];
}
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->Resize(dim_grad);
dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad,
input_grad,
ratio_h,
ratio_w,
in_h,
in_w,
n,
c,
out_h,
out_w,
align_corners,
align_mode,
data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad,
input_grad,
ratio_h,
ratio_w,
n,
c,
out_h,
out_w,
align_corners,
data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolationGrad<T>(output_grad,
input_grad,
ratio_h,
ratio_w,
in_h,
in_w,
n,
c,
out_h,
out_w,
align_corners,
data_layout);
}
}
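// Illustrative sketch (standalone): the ratio computed above maps an output
// coordinate back to the input grid. With align_corners the endpoints are
// pinned, giving a step of (in - 1) / (out - 1); otherwise the step is
// 1 / scale when a positive scale is known, falling back to in / out.
static float ComputeRatio(int in_size, int out_size, float scale,
                          bool align_corners) {
  if (out_size <= 1) return 0.f;  // degenerate axis, same guard as above
  if (align_corners) {
    return static_cast<float>(in_size - 1) / (out_size - 1);
  }
  return (scale > 0.f) ? 1.f / scale
                       : static_cast<float>(in_size) / out_size;
}

// Example: in_h = 4, out_h = 8 gives ratio_h = 3 / 7 with align_corners and
// 0.5 without (no explicit scale).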
template <typename T, typename Context>
static void Interpolate3DCPUBwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
auto out_size_data =
funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
input_grad->Resize(dim_grad);
dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("trilinear" == interp_method) {
TrilinearInterpolationGrad<T>(output_grad,
input_grad,
ratio_d,
ratio_h,
ratio_w,
in_d,
in_h,
in_w,
n,
c,
out_d,
out_h,
out_w,
align_corners,
align_mode,
data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolateGrad<T>(output_grad,
input_grad,
ratio_d,
ratio_h,
ratio_w,
n,
c,
out_d,
out_h,
out_w,
align_corners,
data_layout);
}
}
template <typename T, typename Context>
void InterpolateGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
auto output_grad_dims = output_grad.dims();
if (output_grad_dims.size() == 3) { // 1D interpolation grad
Interpolate1DCPUBwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation grad
Interpolate2DCPUBwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation grad
Interpolate3DCPUBwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
}
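// Illustrative note (standalone sketch): the dispatch above keys purely on the
// rank of out_grad, which matches the rank of x, so interp_method only selects
// the algorithm inside the chosen path.
static const char* CpuGradPathForRank(int rank) {
  switch (rank) {
    case 3: return "Interpolate1DCPUBwd";  // NCW / NWC, "linear"
    case 4: return "Interpolate2DCPUBwd";  // NCHW / NHWC, "bilinear", "nearest", "bicubic"
    case 5: return "Interpolate3DCPUBwd";  // NCDHW / NDHWC, "trilinear", "nearest"
    default: return "unsupported rank";
  }
}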
template <typename T, typename Context>
void BilinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void NearestInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void TrilinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void LinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void BicubicInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
CPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); //
you may not use this file except in compliance with the License. // Licensed under the Apache License, Version 2.0 (the "License");
You may obtain a copy of the License at // you may not use this file except in compliance with the License.
http://www.apache.org/licenses/LICENSE-2.0 // You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software //
distributed under the License is distributed on an "AS IS" BASIS, // http://www.apache.org/licenses/LICENSE-2.0
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
See the License for the specific language governing permissions and // Unless required by applicable law or agreed to in writing, software
limitations under the License. */ // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#pragma once // See the License for the specific language governing permissions and
#include <algorithm> // limitations under the License.
#include <string>
#include <vector> #include "paddle/phi/kernels/interpolate_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { #include "paddle/phi/kernels/funcs/interpolate_function.h"
namespace operators {
namespace phi {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
inline std::vector<int> get_new_shape(
const std::vector<const Tensor*>& list_new_shape_tensor) {
  // get tensor from the list of new-shape tensors
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), phi::make_ddim({1}),
platform::errors::InvalidArgument(
"The shape of dimension tensor should be [1],"
"but received d%.",
tensor->dims()));
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
inline void ExtractNCDWH(const framework::DDim& dims,
const DataLayout& data_layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
if (dims.size() == 3) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2];
*D = 1;
*H = 1;
*W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
} else if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[4];
*D = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*H = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
*W = data_layout == DataLayout::kNCHW ? dims[4] : dims[3];
}
}
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
} else {
output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
}
}
}
}
}
}
template <typename T> template <typename T>
static void NearestNeighbor3DInterpolate( static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
const Tensor& input, Tensor* output, const float ratio_d, T coeffs[4];
const float ratio_h, const float ratio_w, const int n, const int c, funcs::get_cubic_upsample_coefficients<T>(coeffs, t);
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
for (int d = 0; d < out_d; d++) { // loop for images
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) {
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l);
} else { // NDHWC
output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j);
}
}
}
}
}
}
} }
template <typename T> template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output, static void LinearInterpolation(const DenseTensor& input,
const float ratio_w, const int in_w, DenseTensor* output,
const int n, const int c, const int out_w, const float ratio_w,
const bool align_corners, const bool align_mode, const int in_w,
const int n,
const int c,
const int out_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) { const DataLayout data_layout) {
auto input_t = EigenTensor<T, 3>::From(input); auto input_t = EigenTensor<T, 3>::From(input);
auto output_t = EigenTensor<T, 3>::From(*output); auto output_t = EigenTensor<T, 3>::From(*output);
...@@ -223,50 +94,18 @@ static void LinearInterpolation(const Tensor& input, Tensor* output, ...@@ -223,50 +94,18 @@ static void LinearInterpolation(const Tensor& input, Tensor* output,
} }
template <typename T> template <typename T>
static void LinearInterpolationGrad(const Tensor& output_grad, static void BilinearInterpolation(const DenseTensor& input,
Tensor* input_grad, const float ratio_w, DenseTensor* output,
const int in_w, const int n, const int c, const float ratio_h,
const int out_w, const bool align_corners, const float ratio_w,
const int align_mode, const int in_h,
const DataLayout data_layout) { const int in_w,
auto input_grad_t = EigenTensor<T, 3>::From(*input_grad); const int n,
auto output_grad_t = EigenTensor<T, 3>::From(output_grad); const int c,
bool align_flag = (align_mode == 0 && !align_corners); const int out_h,
for (int l = 0; l < out_w; l++) { const int out_w,
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// linear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, l);
input_grad_t(i, j, x_w) += static_cast<T>(grad * d_e);
input_grad_t(i, j, x_e) += static_cast<T>(grad * d_w);
} else {
const T grad = output_grad_t(i, l, j);
input_grad_t(i, x_w, j) += static_cast<T>(grad * d_e);
input_grad_t(i, x_e, j) += static_cast<T>(grad * d_w);
}
}
}
}
}
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners, const bool align_corners,
const bool align_mode, const int align_mode,
const DataLayout data_layout) { const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input); auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output); auto output_t = EigenTensor<T, 4>::From(*output);
...@@ -355,12 +194,136 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, ...@@ -355,12 +194,136 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
} }
template <typename T> template <typename T>
static void TrilinearInterpolation( static void NearestNeighborInterpolate(const DenseTensor& input,
const Tensor& input, Tensor* output, const float ratio_d, DenseTensor* output,
const float ratio_h, const float ratio_w, const int in_d, const int in_h, const float ratio_h,
const int in_w, const int n, const int c, const int out_d, const int out_h, const float ratio_w,
const int out_w, const bool align_corners, const bool align_mode, const int n,
const DataLayout& data_layout) { const int c,
const int out_h,
const int out_w,
const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
} else {
output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
}
}
}
}
}
}
template <typename T>
static void BicubicInterpolation(const DenseTensor& input,
DenseTensor* output,
const float ratio_h,
const float ratio_w,
const int in_h,
const int in_w,
const int n,
const int c,
const int out_h,
const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
const T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
const T x_t = x_n - input_x;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
T coefficients[4];
// interp 4 times in x direction
for (int ii = 0; ii < 4; ii++) {
int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
static_cast<int>(0));
int access_x_0 =
std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
int access_x_1 =
std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
int access_x_2 =
std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
int access_x_3 =
std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
coefficients[ii] =
cubic_interp<T>(input_t(i, j, access_y, access_x_0),
input_t(i, j, access_y, access_x_1),
input_t(i, j, access_y, access_x_2),
input_t(i, j, access_y, access_x_3),
x_t);
} else {
coefficients[ii] =
cubic_interp<T>(input_t(i, access_y, access_x_0, j),
input_t(i, access_y, access_x_1, j),
input_t(i, access_y, access_x_2, j),
input_t(i, access_y, access_x_3, j),
x_t);
}
}
// interp y direction
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) = cubic_interp<T>(coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
y_t);
} else {
output_t(i, k, l, j) = cubic_interp<T>(coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
y_t);
}
}
}
}
}
}
template <typename T>
static void TrilinearInterpolation(const DenseTensor& input,
DenseTensor* output,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const int in_d,
const int in_h,
const int in_w,
const int n,
const int c,
const int out_d,
const int out_h,
const int out_w,
const bool align_corners,
const int align_mode,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input); auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output); auto output_t = EigenTensor<T, 5>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners); bool align_flag = (align_mode == 0 && !align_corners);
...@@ -498,392 +461,78 @@ static void TrilinearInterpolation( ...@@ -498,392 +461,78 @@ static void TrilinearInterpolation(
} }
template <typename T> template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) { static void NearestNeighbor3DInterpolate(const DenseTensor& input,
return ((A + 2) * x - (A + 3)) * x * x + 1; DenseTensor* output,
} const float ratio_d,
const float ratio_h,
template <typename T> const float ratio_w,
HOSTDEVICE inline T cubic_convolution2(T x, T A) { const int n,
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; const int c,
} const int out_d,
const int out_h,
template <typename T> const int out_w,
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) { const bool align_corners,
T A = -0.75; const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
T x1 = t; auto output_t = EigenTensor<T, 5>::From(*output);
coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A); for (int d = 0; d < out_d; d++) { // loop for images
coeffs[1] = cubic_convolution1<T>(x1, A); int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
// opposite coefficients for (int k = 0; k < out_h; k++) {
T x2 = 1.0 - t; int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
coeffs[2] = cubic_convolution1<T>(x2, A); : static_cast<int>(ratio_h * k);
coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
}
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
T coeffs[4];
get_cubic_upsample_coefficients<T>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
const T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) { for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l) int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5); : static_cast<int>(ratio_w * l);
int input_x = floorf(x_n);
const T x_t = x_n - input_x;
for (int i = 0; i < n; i++) { // loop for batches for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels for (int j = 0; j < c; j++) { // loop for channels
T coefficients[4];
// interp 4 times in x direction
for (int ii = 0; ii < 4; ii++) {
int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
static_cast<int>(0));
int access_x_0 =
std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
int access_x_1 =
std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
int access_x_2 =
std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
int access_x_3 =
std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
coefficients[ii] = output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l);
cubic_interp<T>(input_t(i, j, access_y, access_x_0), } else { // NDHWC
input_t(i, j, access_y, access_x_1), output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j);
input_t(i, j, access_y, access_x_2),
input_t(i, j, access_y, access_x_3), x_t);
} else {
coefficients[ii] =
cubic_interp<T>(input_t(i, access_y, access_x_0, j),
input_t(i, access_y, access_x_1, j),
input_t(i, access_y, access_x_2, j),
input_t(i, access_y, access_x_3, j), x_t);
} }
} }
// interp y direction
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
} else {
output_t(i, k, l, j) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
}
} }
} }
} }
} }
} }
template <typename T> template <typename T, typename Context>
static void NearestNeighborInterpolateGrad( static void Interpolate1DCPUFwd(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h, const Context& dev_ctx,
const float ratio_w, const int n, const int c, const int out_h, const DenseTensor& x,
const int out_w, const bool align_corners, const DataLayout data_layout) { paddle::optional<const DenseTensor&> out_size,
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
auto output_grad_t = EigenTensor<T, 4>::From(output_grad); paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
for (int k = 0; k < out_h; k++) { // loop for images int out_w,
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5) const std::vector<float>& scale,
: static_cast<int>(ratio_h * k); const std::string& interp_method,
bool align_corners,
for (int l = 0; l < out_w; l++) { int align_mode,
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5) DenseTensor* output) {
: static_cast<int>(ratio_w * l); const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
} else {
input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
}
}
}
}
}
}
template <typename T>
static void NearestNeighbor3DInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int n, const int c,
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
for (int d = 0; d < out_d; d++) {
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_d, in_k, in_l) +=
output_grad_t(i, j, d, k, l);
} else {
input_grad_t(i, in_d, in_k, in_l, j) +=
output_grad_t(i, d, k, l, j);
}
}
}
}
}
}
}
template <typename T>
static void BilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w, const bool align_corners,
const int align_mode, const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
} else {
const T grad = output_grad_t(i, k, l, j);
input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
}
}
template <typename T>
static void TrilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int j = 0; j < out_d; j++) { // loop for D
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
for (int k = 0; k < out_h; k++) { // loop for H
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) { // loop for W
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
// trilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(b, i, j, k, l);
input_grad_t(b, i, t_f, y_n, x_w) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, i, t_f, y_n, x_e) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, i, t_f, y_s, x_w) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, i, t_f, y_s, x_e) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, i, t_b, y_n, x_w) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, i, t_b, y_n, x_e) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, i, t_b, y_s, x_w) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, i, t_b, y_s, x_e) +=
static_cast<T>(grad * d_f * d_n * d_w);
} else {
const T grad = output_grad_t(b, j, k, l, i);
input_grad_t(b, t_f, y_n, x_w, i) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, t_f, y_n, x_e, i) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, t_f, y_s, x_w, i) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, t_f, y_s, x_e, i) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, t_b, y_n, x_w, i) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, t_b, y_n, x_e, i) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, t_b, y_s, x_w, i) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, t_b, y_s, x_e, i) +=
static_cast<T>(grad * d_f * d_n * d_w);
}
}
}
}
}
}
}
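// Illustrative sketch (helper not present in the op): each of the eight
// trilinear corner weights used above is the product of one weight per axis
// (front/back, north/south, west/east), so the eight weights also sum to 1.
template <typename T>
inline T TrilinearCornerWeight(int corner, T d_f, T d_n, T d_w) {
  T wd = (corner & 4) ? d_f : static_cast<T>(1) - d_f;  // back plane t_b vs front plane t_f
  T wh = (corner & 2) ? d_n : static_cast<T>(1) - d_n;  // south row y_s vs north row y_n
  T ww = (corner & 1) ? d_w : static_cast<T>(1) - d_w;  // east column x_e vs west column x_w
  return wd * wh * ww;
}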
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
  for (int k = 0; k < out_h; k++) {  // loop for H
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
T x_t = x_n - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
get_cubic_upsample_coefficients<T>(y_coeffs, y_t);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bicubic interpolation grad
for (int ii = 0; ii < 4; ii++) {
for (int jj = 0; jj < 4; jj++) {
int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
static_cast<int>(0));
int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, access_y, access_x) +=
grad * y_coeffs[jj] * x_coeffs[ii];
} else {
T grad = output_grad_t(i, k, l, j);
input_grad_t(i, access_y, access_x, j) +=
grad * y_coeffs[jj] * x_coeffs[ii];
}
}
}
}
}
}
}
}
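// Illustrative sketch (helper not present in the op, assuming <algorithm> is
// available here): the backward pass above scatters every output gradient to a
// 4x4 input stencil; the min/max pair simply clamps stencil taps that fall
// outside the image to the border pixel.
inline int ClampStencilIndex(int idx, int size) {
  return std::max(std::min(idx, size - 1), 0);
}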
template <typename T>
static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
float scale_w = -1.; float scale_w = -1.;
if (list_new_size_tensor.size() > 0) { if (size_tensor && size_tensor->size() > 0) {
// have size tensor // have size tensor
auto new_size = get_new_shape(list_new_size_tensor); auto new_size = funcs::get_new_shape(size_tensor.get());
out_w = new_size[0]; out_w = new_size[0];
} else { } else {
// float scale_w = -1; if (scale_tensor) {
auto scale_tensor = ctx.Input<Tensor>("Scale"); auto scale_data =
auto scale = ctx.Attr<std::vector<float>>("scale"); funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0]; scale_w = scale_data[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) " "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
...@@ -892,8 +541,9 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
scale_w = scale[0]; scale_w = scale[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) " "The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
...@@ -902,25 +552,28 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
if (scale_w > 0.) { if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w); out_w = static_cast<int>(in_w * scale_w);
} }
auto out_size = ctx.Input<Tensor>("OutSize"); if (out_size) {
if (out_size != nullptr) { auto out_size_data =
auto out_size_data = get_new_data_from_tensor<int>(out_size); funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_w = out_size_data[0]; out_w = out_size_data[0];
} }
} }
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"out_w in Attr(out_shape) of Op(interpolate) " out_w,
"should be greater than 0.")); 0,
framework::DDim dim_out; errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w}; dim_out = {n, c, out_w};
} else { } else {
dim_out = {n, out_w, c}; dim_out = {n, out_w, c};
} }
output->mutable_data<T>(dim_out, ctx.GetPlace()); output->Resize(dim_out);
dev_ctx.template Alloc<T>(output);
if (in_w == out_w) { if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output); paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), output);
return; return;
} }
...@@ -933,39 +586,51 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
: static_cast<float>(new_scale_w); : static_cast<float>(new_scale_w);
} }
if ("linear" == interp_method) { if ("linear" == interp_method) {
LinearInterpolation<T>(input, output, ratio_w, in_w, n, c, out_w, LinearInterpolation<T>(x,
align_corners, align_mode, data_layout); output,
ratio_w,
in_w,
n,
c,
out_w,
align_corners,
align_mode,
data_layout);
} }
} }
template <typename T> template <typename T, typename Context>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, static void Interpolate2DCPUFwd(
const Tensor& input, Tensor* output) { const Context& dev_ctx,
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const DenseTensor& x,
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1; float scale_h = -1;
float scale_w = -1; float scale_w = -1;
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor"); if (size_tensor && size_tensor->size() > 0) {
if (list_new_size_tensor.size() > 0) {
// have size tensor // have size tensor
auto new_size = get_new_shape(list_new_size_tensor); auto new_size = funcs::get_new_shape(size_tensor.get());
out_h = new_size[0]; out_h = new_size[0];
out_w = new_size[1]; out_w = new_size[1];
} else { } else {
auto scale_tensor = ctx.Input<Tensor>("Scale"); if (scale_tensor) {
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale_data =
if (scale_tensor != nullptr) { funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_h = scale_data[0]; scale_h = scale_data[0];
scale_w = scale_data[1]; scale_w = scale_data[1];
...@@ -974,14 +639,16 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
scale_w = scale_data[0]; scale_w = scale_data[0];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) " "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_h > 0, true, scale_h > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) " "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_h)); scale_h));
...@@ -991,14 +658,16 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
scale_w = scale[1]; scale_w = scale[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) " "The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_h > 0, true, scale_h > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) " "The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_h)); scale_h));
...@@ -1008,29 +677,34 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
out_h = static_cast<int>(in_h * scale_h); out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w); out_w = static_cast<int>(in_w * scale_w);
} }
auto out_size = ctx.Input<Tensor>("OutSize"); if (out_size) {
if (out_size != nullptr) { auto out_size_data =
auto out_size_data = get_new_data_from_tensor<int>(out_size); funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} }
} }
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"out_h in Attr(out_shape) of Op(interpolate) " out_h,
"should be greater than 0.")); 0,
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( errors::InvalidArgument("out_h in Attr(out_shape) of Op(interpolate) "
"out_w in Attr(out_shape) of Op(interpolate) " "should be greater than 0."));
"should be greater than 0.")); PADDLE_ENFORCE_GT(
framework::DDim dim_out; out_w,
0,
errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w}; dim_out = {n, c, out_h, out_w};
} else { } else {
dim_out = {n, out_h, out_w, c}; dim_out = {n, out_h, out_w, c};
} }
output->mutable_data<T>(dim_out, ctx.GetPlace()); output->Resize(dim_out);
dev_ctx.template Alloc<T>(output);
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output); paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), output);
return; return;
} }
...@@ -1052,50 +726,81 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
} }
if ("bilinear" == interp_method) { if ("bilinear" == interp_method) {
BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c, BilinearInterpolation<T>(x,
out_h, out_w, align_corners, align_mode, output,
ratio_h,
ratio_w,
in_h,
in_w,
n,
c,
out_h,
out_w,
align_corners,
align_mode,
data_layout); data_layout);
} else if ("nearest" == interp_method) { } else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h, NearestNeighborInterpolate<T>(x,
out_w, align_corners, data_layout); output,
ratio_h,
ratio_w,
n,
c,
out_h,
out_w,
align_corners,
data_layout);
} else if ("bicubic" == interp_method) { } else if ("bicubic" == interp_method) {
BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c, BicubicInterpolation<T>(x,
out_h, out_w, align_corners, data_layout); output,
ratio_h,
ratio_w,
in_h,
in_w,
n,
c,
out_h,
out_w,
align_corners,
data_layout);
} }
} }
template <typename T> template <typename T, typename Context>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, static void Interpolate3DCPUFwd(
const Tensor& input, Tensor* output) { const Context& dev_ctx,
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const DenseTensor& x,
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_d = -1; float scale_d = -1;
float scale_h = -1; float scale_h = -1;
float scale_w = -1; float scale_w = -1;
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor"); if (size_tensor && size_tensor->size() > 0) {
if (list_new_size_tensor.size() > 0) {
// have size tensor // have size tensor
auto new_size = get_new_shape(list_new_size_tensor); auto new_size = funcs::get_new_shape(size_tensor.get());
out_d = new_size[0]; out_d = new_size[0];
out_h = new_size[1]; out_h = new_size[1];
out_w = new_size[2]; out_w = new_size[2];
} else { } else {
auto scale_tensor = ctx.Input<Tensor>("Scale"); if (scale_tensor) {
auto scale = ctx.Attr<std::vector<float>>("scale"); auto scale_data =
if (scale_tensor != nullptr) { funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) { if (scale_data.size() > 1) {
scale_d = scale_data[0]; scale_d = scale_data[0];
scale_h = scale_data[1]; scale_h = scale_data[1];
...@@ -1106,20 +811,23 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
scale_w = scale_data[0]; scale_w = scale_data[0];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) " "The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_h > 0, true, scale_h > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) " "The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_h)); scale_h));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_d > 0, true, scale_d > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) " "The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_d)); scale_d));
...@@ -1130,20 +838,23 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
scale_w = scale[2]; scale_w = scale[2];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_w > 0, true, scale_w > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) " "The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_w)); scale_w));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_h > 0, true, scale_h > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) " "The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_h)); scale_h));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale_d > 0, true, scale_d > 0,
platform::errors::InvalidArgument( true,
errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) " "The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.", "should be greater than 0, but received value is %d.",
scale_d)); scale_d));
...@@ -1154,35 +865,42 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
out_h = static_cast<int>(in_h * scale_h); out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w); out_w = static_cast<int>(in_w * scale_w);
} }
auto out_size = ctx.Input<Tensor>("OutSize"); if (out_size) {
if (out_size != nullptr) { auto out_size_data =
auto out_size_data = get_new_data_from_tensor<int>(out_size); funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
out_d = out_size_data[0]; out_d = out_size_data[0];
out_h = out_size_data[1]; out_h = out_size_data[1];
out_w = out_size_data[2]; out_w = out_size_data[2];
} }
} }
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"out_d in Attr(out_shape) of Op(interpolate) " out_d,
"should be greater than 0.")); 0,
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( errors::InvalidArgument("out_d in Attr(out_shape) of Op(interpolate) "
"out_h in Attr(out_shape) of Op(interpolate) " "should be greater than 0."));
"should be greater than 0.")); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( out_h,
"out_w in Attr(out_shape) of Op(interpolate) " 0,
"should be greater than 0.")); errors::InvalidArgument("out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out; PADDLE_ENFORCE_GT(
out_w,
0,
errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) { if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w}; dim_out = {n, c, out_d, out_h, out_w};
} else { } else {
dim_out = {n, out_d, out_h, out_w, c}; dim_out = {n, out_d, out_h, out_w, c};
} }
output->mutable_data<T>(dim_out, ctx.GetPlace()); output->Resize(dim_out);
dev_ctx.template Alloc<T>(output);
if (in_d == out_d && in_h == out_h && in_w == out_w) { if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output); paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), output);
return; return;
} }
...@@ -1212,407 +930,296 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
} }
if ("trilinear" == interp_method) { if ("trilinear" == interp_method) {
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d, TrilinearInterpolation<T>(x,
in_h, in_w, n, c, out_d, out_h, out_w, output,
align_corners, align_mode, data_layout); ratio_d,
ratio_h,
ratio_w,
in_d,
in_h,
in_w,
n,
c,
out_d,
out_h,
out_w,
align_corners,
align_mode,
data_layout);
} else if ("nearest" == interp_method) { } else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolate<T>(input, output, ratio_d, ratio_h, ratio_w, n, NearestNeighbor3DInterpolate<T>(x,
c, out_d, out_h, out_w, align_corners, output,
ratio_d,
ratio_h,
ratio_w,
n,
c,
out_d,
out_h,
out_w,
align_corners,
data_layout); data_layout);
} }
} }
template <typename T> template <typename T, typename Context>
static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx, void InterpolateKernel(
Tensor* input_grad, const Tensor& output_grad) { const Context& ctx,
auto* input = ctx.Input<Tensor>("X"); const DenseTensor& x,
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); paddle::optional<const DenseTensor&> out_size,
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
int n, c, in_d, in_h, in_w; paddle::optional<const DenseTensor&> scale_tensor,
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); const std::string& data_layout,
int out_d,
auto interp_method = ctx.Attr<std::string>("interp_method"); int out_h,
bool align_corners = ctx.Attr<bool>("align_corners"); int out_w,
int align_mode = ctx.Attr<int>("align_mode"); const std::vector<float>& scale,
const std::string& interp_method,
int out_w = ctx.Attr<int>("out_w"); bool align_corners,
float scale_w = -1.0; int align_mode,
auto scale_tensor = ctx.Input<Tensor>("Scale"); DenseTensor* output) {
auto scale = ctx.Attr<std::vector<float>>("scale"); auto input_dims = x.dims();
if (scale_tensor != nullptr) { if (input_dims.size() == 3) { // 1D interpolation
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); Interpolate1DCPUFwd<T, Context>(ctx,
scale_w = scale_data[0]; x,
PADDLE_ENFORCE_EQ( out_size,
scale_w > 0, true, size_tensor,
platform::errors::InvalidArgument( scale_tensor,
"The scale_w in input 'Scale' Tensor of Operator(interpolate) " data_layout,
"should be greater than 0, but received value is %d.", out_w,
scale_w)); scale,
} else { interp_method,
if (scale.size() > 0) { align_corners,
scale_w = scale[0]; align_mode,
PADDLE_ENFORCE_EQ( output);
scale_w > 0, true,
platform::errors::InvalidArgument( } else if (input_dims.size() == 4) { // 2D interpolation
"The scale_w in Attr(scale) of Operator(interpolate) " Interpolate2DCPUFwd<T>(ctx,
"should be greater than 0, but received value is %d.", x,
scale_w)); out_size,
} size_tensor,
} scale_tensor,
if (scale_w > 0.) { data_layout,
out_w = static_cast<int>(in_w * scale_w); out_h,
} out_w,
auto out_size = ctx.Input<Tensor>("OutSize"); scale,
if (out_size != nullptr) { interp_method,
auto out_size_data = get_new_data_from_tensor<int>(out_size); align_corners,
out_w = out_size_data[0]; align_mode,
} output);
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor"); } else if (input_dims.size() == 5) { // 3D interpolation
if (list_new_size_tensor.size() > 0) { Interpolate3DCPUFwd<T>(ctx,
// have size tensor x,
auto new_size = get_new_shape(list_new_size_tensor); out_size,
out_w = new_size[0]; size_tensor,
} scale_tensor,
data_layout,
framework::DDim dim_grad; out_d,
if (data_layout == DataLayout::kNCHW) { out_h,
dim_grad = {n, c, in_w}; out_w,
} else { scale,
dim_grad = {n, in_w, c}; interp_method,
} align_corners,
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace()); align_mode,
output);
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("linear" == interp_method) {
LinearInterpolationGrad<T>(output_grad, input_grad, ratio_w, in_w, n, c,
out_w, align_corners, align_mode, data_layout);
} }
} }
template <typename T> template <typename T, typename Context>
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, void BilinearInterpKernel(
Tensor* input_grad, const Tensor& output_grad) { const Context& ctx,
auto* input = ctx.Input<Tensor>("X"); const DenseTensor& x,
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); paddle::optional<const DenseTensor&> out_size,
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
int n, c, in_d, in_h, in_w; paddle::optional<const DenseTensor&> scale_tensor,
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); const std::string& data_layout,
int out_d,
auto interp_method = ctx.Attr<std::string>("interp_method"); int out_h,
bool align_corners = ctx.Attr<bool>("align_corners"); int out_w,
int align_mode = ctx.Attr<int>("align_mode"); const std::vector<float>& scale,
const std::string& interp_method,
int out_h = ctx.Attr<int>("out_h"); bool align_corners,
int out_w = ctx.Attr<int>("out_w"); int align_mode,
float scale_h = -1; DenseTensor* output) {
float scale_w = -1; InterpolateKernel<T, Context>(ctx,
auto scale_tensor = ctx.Input<Tensor>("Scale"); x,
auto scale = ctx.Attr<std::vector<float>>("scale"); out_size,
if (scale_tensor != nullptr) { size_tensor,
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); scale_tensor,
if (scale_data.size() > 1) { data_layout,
scale_h = scale_data[0]; out_d,
scale_w = scale_data[1]; out_h,
} else { out_w,
scale_w = scale_data[0]; scale,
scale_h = scale_data[0]; interp_method,
} align_corners,
PADDLE_ENFORCE_EQ( align_mode,
scale_w > 0, true, output);
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners,
data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
in_w, n, c, out_h, out_w, align_corners,
data_layout);
}
} }
template <typename T> template <typename T, typename Context>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, void NearestInterpKernel(
Tensor* input_grad, const Tensor output_grad) { const Context& ctx,
auto* input = ctx.Input<Tensor>("X"); const DenseTensor& x,
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); paddle::optional<const DenseTensor&> out_size,
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
int n, c, in_d, in_h, in_w; paddle::optional<const DenseTensor&> scale_tensor,
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); const std::string& data_layout,
int out_d,
auto interp_method = ctx.Attr<std::string>("interp_method"); int out_h,
bool align_corners = ctx.Attr<bool>("align_corners"); int out_w,
int align_mode = ctx.Attr<int>("align_mode"); const std::vector<float>& scale,
const std::string& interp_method,
int out_d = ctx.Attr<int>("out_d"); bool align_corners,
int out_h = ctx.Attr<int>("out_h"); int align_mode,
int out_w = ctx.Attr<int>("out_w"); DenseTensor* output) {
float scale_d = -1; InterpolateKernel<T, Context>(ctx,
float scale_h = -1; x,
float scale_w = -1; out_size,
auto scale_tensor = ctx.Input<Tensor>("Scale"); size_tensor,
auto scale = ctx.Attr<std::vector<float>>("scale"); scale_tensor,
if (scale_tensor != nullptr) { data_layout,
auto scale_data = get_new_data_from_tensor<float>(scale_tensor); out_d,
if (scale_data.size() > 1) { out_h,
scale_d = scale_data[0]; out_w,
scale_h = scale_data[1]; scale,
scale_w = scale_data[2]; interp_method,
} else { align_corners,
scale_d = scale_data[0]; align_mode,
scale_h = scale_data[0]; output);
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0, true,
platform::errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
phi::funcs::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
if ("trilinear" == interp_method) {
TrilinearInterpolationGrad<T>(
output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolateGrad<T>(output_grad, input_grad, ratio_d,
ratio_h, ratio_w, n, c, out_d, out_h,
out_w, align_corners, data_layout);
}
} }
template <typename T> template <typename T, typename Context>
class InterpolateV2Kernel : public framework::OpKernel<T> { void TrilinearInterpKernel(
public: const Context& ctx,
void Compute(const framework::ExecutionContext& ctx) const override { const DenseTensor& x,
auto* input = ctx.Input<Tensor>("X"); paddle::optional<const DenseTensor&> out_size,
auto* output = ctx.Output<Tensor>("Out"); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
auto input_dims = input->dims(); const std::string& data_layout,
if (input_dims.size() == 3) { // 1D interpolation int out_d,
Interpolate1DCPUFwd<T>(ctx, *input, output); int out_h,
} else if (input_dims.size() == 4) { // 2D interpolation int out_w,
Interpolate2DCPUFwd<T>(ctx, *input, output); const std::vector<float>& scale,
} else if (input_dims.size() == 5) { // 3D interpolation const std::string& interp_method,
Interpolate3DCPUFwd<T>(ctx, *input, output); bool align_corners,
} int align_mode,
} DenseTensor* output) {
}; InterpolateKernel<T, Context>(ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
template <typename T> template <typename T, typename Context>
class InterpolateV2GradKernel : public framework::OpKernel<T> { void LinearInterpKernel(
public: const Context& ctx,
void Compute(const framework::ExecutionContext& ctx) const override { const DenseTensor& x,
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); paddle::optional<const DenseTensor&> out_size,
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
auto output_grad_dims = output_grad->dims(); template <typename T, typename Context>
if (output_grad_dims.size() == 3) { // 1D interpolation grad void BicubicInterpKernel(
Interpolate1DCPUBwd<T>(ctx, input_grad, *output_grad); const Context& ctx,
} else if (output_grad_dims.size() == 4) { // 2D interpolation grad const DenseTensor& x,
Interpolate2DCPUBwd<T>(ctx, input_grad, *output_grad); paddle::optional<const DenseTensor&> out_size,
} else if (output_grad_dims.size() == 5) { // 3D interpolation grad paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
Interpolate3DCPUBwd<T>(ctx, input_grad, *output_grad); paddle::optional<const DenseTensor&> scale_tensor,
} const std::string& data_layout,
} int out_d,
}; int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
} // namespace operators } // namespace phi
} // namespace paddle
PD_REGISTER_KERNEL(bilinear_interp_v2,
CPU,
ALL_LAYOUT,
phi::BilinearInterpKernel,
float,
double,
uint8_t) {}
PD_REGISTER_KERNEL(nearest_interp_v2,
CPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
float,
double,
int,
int64_t,
uint8_t) {}
PD_REGISTER_KERNEL(trilinear_interp_v2,
CPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
uint8_t) {}
PD_REGISTER_KERNEL(linear_interp_v2,
CPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
uint8_t) {}
PD_REGISTER_KERNEL(bicubic_interp_v2,
CPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double) {}
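// The five kernels registered above are thin wrappers around InterpolateKernel,
// which picks the 1-D, 2-D or 3-D CPU path from the input rank. A minimal
// sketch of that dispatch rule (free function for illustration only):
inline int InterpSpatialRank(const phi::DDim& x_dims) {
  // rank 3 -> 1-D (linear), rank 4 -> 2-D (bilinear/nearest/bicubic),
  // rank 5 -> 3-D (trilinear/nearest)
  return static_cast<int>(x_dims.size()) - 2;
}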
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/phi/core/hostdevice.h"
namespace phi {
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/fast_divmod.h"
#endif
namespace phi {
namespace funcs {
template <typename T>
HOSTDEVICE inline T CubicConvolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename T>
HOSTDEVICE inline T CubicConvolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75;
T x1 = t;
coeffs[0] = CubicConvolution2<T>(x1 + 1.0, A);
coeffs[1] = CubicConvolution1<T>(x1, A);
// opposite coefficients
T x2 = 1.0 - t;
coeffs[2] = CubicConvolution1<T>(x2, A);
coeffs[3] = CubicConvolution2<T>(x2 + 1.0, A);
}
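// Worked check for the Keys cubic weights above (A = -0.75):
//   t = 0   -> coeffs = {0, 1, 0, 0}, i.e. the sample lands exactly on a pixel;
//   t = 0.5 -> coeffs = {-0.09375, 0.59375, 0.59375, -0.09375}.
// In both cases the four coefficients sum to 1, so bicubic resampling leaves a
// constant image unchanged. Minimal sketch of that check (not used by the kernels):
template <typename T>
HOSTDEVICE inline T cubic_coefficient_sum(T t) {
  T coeffs[4];
  get_cubic_upsample_coefficients<T>(coeffs, t);
  return coeffs[0] + coeffs[1] + coeffs[2] + coeffs[3];  // ~1 up to rounding
}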
inline void ExtractNCDWH(const DDim& dims,
const DataLayout& data_layout,
int* N,
int* C,
int* D,
int* H,
int* W) {
*N = dims[0];
if (dims.size() == 3) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2];
*D = 1;
*H = 1;
*W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
} else if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[4];
*D = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*H = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
*W = data_layout == DataLayout::kNCHW ? dims[4] : dims[3];
}
}
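// Example of the mapping above (values are illustrative): a 4-D input with
// dims = {2, 8, 16, 3} and data_layout != kNCHW (i.e. NHWC) yields
// N = 2, C = 3, D = 1, H = 8, W = 16; the same logical shape stored as NCHW,
// dims = {2, 3, 8, 16}, yields identical N/C/D/H/W values.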
inline std::vector<int> get_new_shape(
const std::vector<const DenseTensor*>& list_new_shape_tensor) {
  // read the target shape values from the list of 1-element shape tensors
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(),
phi::make_ddim({1}),
errors::InvalidArgument("The shape of dimension tensor should be [1],"
"but received d%.",
tensor->dims()));
if (paddle::platform::is_gpu_place(tensor->place())) {
DenseTensor temp;
paddle::framework::TensorCopySync(
*tensor, paddle::platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor(
const DenseTensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
DenseTensor cpu_starts_tensor;
if (paddle::platform::is_gpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#ifdef PADDLE_WITH_ASCEND_CL
if (paddle::platform::is_npu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
#ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
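// Usage sketch (tensor name is illustrative): reading a one-element Scale
// tensor that may live on GPU/NPU/XPU always goes through a synchronous copy
// to CPU first, so the returned std::vector can be indexed immediately.
inline float ReadScaleW(const DenseTensor& scale_tensor) {
  auto scale = get_new_data_from_tensor<float>(&scale_tensor);
  return scale.empty() ? -1.f : scale[0];
}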
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::FastDivMod;
struct FastDivModForInterpolate {
public:
FastDivMod channels_div;
FastDivMod output_w_div;
FastDivMod output_wc_div;
explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
const int output_w,
                                               const int output_wc)
      : channels_div(FastDivMod(channels)),
        output_w_div(FastDivMod(output_w)),
        output_wc_div(FastDivMod(output_wc)) {}
};
#endif
} // namespace funcs
} // namespace phi
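// Minimal sketch of the NHWC index decomposition that FastDivModForInterpolate
// (above) accelerates; plain / and % are shown here, while the GPU kernels use
// precomputed magic-number divisions to recover the same n/h/w/c values.
inline void DecomposeNHWCIndex(int tid, int channels, int out_h, int out_w,
                               int* n, int* h, int* w, int* c) {
  *c = tid % channels;
  int whn = tid / channels;
  *w = whn % out_w;
  int hn = whn / out_w;
  *h = hn % out_h;
  *n = hn / out_h;
}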
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/interpolate_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
int* in_img_idx,
int* x_id,
T* lambda1,
T* lambda2,
T src_x,
const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f;
*in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx;
*lambda2 = 1.f - *lambda1;
}
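// Worked example for the helper above (illustrative numbers): src_x = 2.3 with
// in_img_x = 5 gives in_img_idx = 2, x_id = 1 (a right neighbour exists),
// lambda1 = 0.3 and lambda2 = 0.7, i.e. a gradient at this position is split
// 70% / 30% between input columns 2 and 3. At the right border
// (in_img_idx == in_img_x - 1) x_id becomes 0, so both shares land on the same
// column.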
template <typename T>
__global__ void KeLinearInterpBw(T* in,
const size_t in_img_w,
const size_t input_w,
const T* out,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const T ratio_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
}
const T* out_pos = &out[out_id_w];
if (data_layout == DataLayout::kNCHW) {
paddle::platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
} else {
paddle::platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
w1lambda * out_pos[0]);
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpNCHWBw(T* in,
const size_t in_img_h,
const size_t in_img_w,
const T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t nc,
const float ratio_h,
const float ratio_w,
const bool align_corners) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
  // nearest neighbour sampling: route each output gradient back to its single
  // nearest input position
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
  // bounds check: threads outside the output image must not read or write
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
T* in_pos = &in[in_index];
const T out_pos = out[out_index];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpBw(
T* in,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
const T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_h,
const float ratio_w,
const bool align_corners,
funcs::FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
const T out_pos = out[tid];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
}
}
/* Calculate the minimum over the first threads_num_in_block elements of a block */
template <typename T>
__inline__ __device__ T PartialBlockMin(T val,
size_t threads_num_in_block,
unsigned mask) {
__shared__ T shared[WARP_SIZE];
__shared__ T shared_last_val;
__shared__ int shared_last_idx;
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
int threshold = (threads_num_in_block & (-WARP_SIZE));
if (threadIdx.x < threshold) {
shared_last_idx = (threshold >> 5) - 1;
val = phi::funcs::warpReduceMin(val, mask);
if (lane == 0) {
shared[wid] = val;
}
} else {
shared_last_val = std::numeric_limits<T>::max();
paddle::platform::CudaAtomicMin(&shared_last_val, val);
shared[wid] = shared_last_val;
shared_last_idx = wid;
}
__syncthreads();
if (threadIdx.x < threshold) {
val = (lane <= shared_last_idx) ? shared[lane]
: std::numeric_limits<T>::max();
val = phi::funcs::warpReduceMin(val, mask);
shared_last_val = val;
}
__syncthreads();
if (threadIdx.x >= threshold) {
val = shared_last_val;
}
return val;
}
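// PartialBlockMin exists because the grid-stride loop below can leave the last
// chunk of work with fewer than blockDim.x valid items ("remain"): full warps
// still reduce with shuffles, while the trailing partial warp falls back to an
// atomic min in shared memory. The callers only take this path when
// remain <= blockDim.x.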
template <typename T>
__global__ void KeBilinearInterpBwShareMemory(T* in,
const int in_h,
const int in_w,
const T* __restrict__ out,
const int out_h,
const int out_w,
const int n,
const int num_channels,
float ratio_h,
float ratio_w,
const T align_type_value,
bool is_nchw) {
__shared__ T s_data[2][1024];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int out_chw = num_channels * out_h * out_w;
int nthreads = n * out_chw;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(
&in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_h);
// top_left_index is just input_index.
int input_index = out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx;
int top_right_index = input_index + w_id;
int bot_left_index = input_index + h_id * in_w;
int bot_right_index = input_index + h_id * in_w + w_id;
int in_top_min_index, in_bot_min_index;
s_data[0][threadIdx.x] = 0.f;
s_data[1][threadIdx.x] = 0.f;
int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index =
phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
int in_bot_max_index =
phi::funcs::blockReduceMax(bot_right_index, FINAL_MASK);
if (remain > blockDim.x) {
in_top_min_index = phi::funcs::blockReduceMin(input_index, FINAL_MASK);
in_bot_min_index = phi::funcs::blockReduceMin(bot_left_index, FINAL_MASK);
} else {
in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
}
int upper_limit_share_idx = (in_top_max_index - in_top_min_index) >
(in_bot_max_index - in_bot_min_index)
? (in_top_max_index - in_top_min_index)
: (in_bot_max_index - in_bot_min_index);
if (h_id != 0) {
paddle::platform::CudaAtomicAdd(
&s_data[0][input_index - in_top_min_index],
h2lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[0][top_right_index - in_top_min_index],
h2lambda * w1lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_left_index - in_bot_min_index],
h1lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_right_index - in_bot_min_index],
h1lambda * w1lambda * value);
} else {
paddle::platform::CudaAtomicAdd(
&s_data[0][top_right_index - in_top_min_index],
(h2lambda + h1lambda) * w1lambda * value);
paddle::platform::CudaAtomicAdd(
&s_data[1][bot_left_index - in_bot_min_index],
(h1lambda + h2lambda) * w2lambda * value);
}
__syncthreads();
if (threadIdx.x <= upper_limit_share_idx) {
paddle::platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
paddle::platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
}
}
}
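// The kernel above stages the gradient scatter in shared memory: each block
// first reduces the min/max input offsets of the top (y_n) and bottom (y_s)
// rows it touches, accumulates all of its contributions into the two s_data
// rows with shared-memory atomics, and only then flushes each staged element
// to global memory with a single atomic add, reducing global atomic traffic.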
__device__ __forceinline__ int GetInputIndex(const size_t nc,
const int height,
const int width,
const int h,
const int w) {
return (nc * height + h) * width + w;
}
template <typename T>
__global__ void KeBilinearInterpNCHWBw(T* in,
const int in_h,
const int in_w,
const int out_h,
const int out_w,
const int n,
const int num_channels,
float ratio_h,
float ratio_w,
const T* __restrict__ out,
const T align_type_value) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x;
int num_out = n * num_channels * out_h * out_w;
int num_in = n * num_channels * in_h * in_w;
for (; index < num_out; index += stride) {
int index_tmp = index;
int w2 = index_tmp % out_w;
index_tmp /= out_w;
int h2 = index_tmp % out_h;
int nc = index_tmp / out_h;
int h1, y_id;
T h1lambda, h0lambda;
T src_y = ratio_h * (h2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
int w1, x_id;
T w1lambda, w0lambda;
T src_x = ratio_w * (w2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
T d2val = out[index];
paddle::platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
paddle::platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
}
}
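// NHWC variant of the bilinear backward kernel: the flat output index is
// decomposed with the funcs::FastDivModForInterpolate helpers instead of
// '/' and '%', and the four atomicAdd targets are spaced by num_channels
// (and in_w * num_channels) because channels are the innermost dimension.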
template <typename T>
__global__ void KeBilinearInterpBw(T* in,
const int in_h,
const int in_w,
const T* __restrict__ out,
const int out_h,
const int out_w,
const int n,
const int out_chw,
const int num_channels,
float ratio_h,
float ratio_w,
const T align_type_value,
funcs::FastDivModForInterpolate divmods) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int nthreads = n * out_chw;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(
&in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_h);
T value = out[tid];
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
paddle::platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
paddle::platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
paddle::platform::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
}
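// Bicubic backward: every output-gradient element is scattered onto a 4 x 4
// window of input positions with separable weights x_coeffs[i] * y_coeffs[j]
// from funcs::get_cubic_upsample_coefficients; the window indices are
// clamped to the image border before each atomicAdd.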
template <typename T>
__global__ void KeBicubicInterpBw(T* in,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
const T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_h,
const float ratio_w,
const bool align_corners,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T x_coeffs[4];
T y_coeffs[4];
funcs::get_cubic_upsample_coefficients(x_coeffs, x_t);
funcs::get_cubic_upsample_coefficients(y_coeffs, y_t);
const T* out_pos = &out[out_id_h * output_w + out_id_w];
T* in_pos;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
int access_y = max(min(static_cast<int>(input_y - 1 + j),
static_cast<int>(in_img_h - 1)),
0);
int access_x = max(min(static_cast<int>(input_x - 1 + i),
static_cast<int>(in_img_w - 1)),
0);
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x];
} else {
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
paddle::platform::CudaAtomicAdd(
&in_pos[0], (out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
}
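// Trilinear backward: the output gradient is scattered onto the eight
// corners of the enclosing input cell with weights
// d{1,2}lambda * h{1,2}lambda * w{1,2}lambda, which sum to 1. Channel-first
// and channel-last layouts differ only in the strides of the eight targets.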
template <typename T>
__global__ void KeTrilinearInterpBw(T* in,
const size_t in_img_d,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
const T* out,
const size_t out_img_d,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const T ratio_d,
const T ratio_h,
const T ratio_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
paddle::platform::CudaAtomicAdd(
&in_pos1[0], d2lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[w_id], d2lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[0], d1lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[w_id], d1lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
paddle::platform::CudaAtomicAdd(
&in_pos1[0], d2lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[w_id * num_channels],
d2lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels],
d2lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels + w_id * num_channels],
d2lambda * h1lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[0], d1lambda * h2lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[w_id * num_channels],
d1lambda * h2lambda * w1lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels],
d1lambda * h1lambda * w2lambda * out_pos[0]);
paddle::platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels + w_id * num_channels],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
}
}
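// 3D nearest-neighbor backward: the whole output gradient is accumulated
// into the single nearest input voxel. With align_corners the source index
// is rounded (ratio * idx + 0.5); otherwise it is truncated.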
template <typename T>
__global__ void KeNearestNeighbor3DInterpBw(T* in,
const size_t in_img_d,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
const T* out,
const size_t out_img_d,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const bool align_corners,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
in_pos = &in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
paddle::platform::CudaAtomicAdd(in_pos, out_pos);
}
}
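// Host-side driver for the 1D (linear) backward pass. The output width is
// resolved with SizeTensor taking priority over OutSize, which in turn
// overrides the scale (tensor or attribute). The width ratio maps output to
// input coordinates; e.g. with in_w = 10, out_w = 20 and no scale given,
// ratio_w = 10 / 20 = 0.5 without align_corners and (10 - 1) / (20 - 1)
// with align_corners.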
template <typename T, typename Context>
static void Interpolate1DCUDABwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_w = -1;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_w = new_size[0];
}
auto* output_grad_data = output_grad.data<T>();
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->Resize(dim_grad);
auto* input_grad_data = dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_cw = c * in_w;
int64_t out_cw = c * out_w;
auto pixelNum = n * out_cw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_w,
in_cw,
output_grad_data,
out_w,
n,
out_cw,
c,
ratio_w,
align_corners,
align_mode,
data_layout);
}
}
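// Host-side driver for the 2D backward pass. After zero-initializing the
// gradient buffer it dispatches on interp_method and layout:
//   - "nearest": a 3D launch over (w, h, n*c) for NCHW, a 1D grid-stride
//     kernel with fast div/mod for NHWC;
//   - "bilinear": a shared-memory kernel for strongly upscaled NCHW cases,
//     otherwise the plain NCHW or NHWC kernel;
//   - "bicubic": a fixed block size (256 on HIP, 512 on CUDA).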
template <typename T, typename Context>
static void Interpolate2DCUDABwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_h = -1;
float scale_w = -1;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_h = new_size[0];
out_w = new_size[1];
}
auto* output_grad_data = output_grad.data<T>();
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->Resize(dim_grad);
auto* input_grad_data = dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_hw = in_h * in_w;
int64_t out_hw = out_h * out_w;
int64_t in_chw = c * in_hw;
int64_t out_chw = c * out_hw;
auto pixelNum = n * out_chw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("nearest" == interp_method) {
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
backends::gpu::GpuLaunchConfig config_3d =
backends::gpu::GetGpuLaunchConfig3D(dev_ctx, nc, out_h, out_w);
KeNearestNeighborInterpNCHWBw<T><<<config_3d.block_per_grid,
config_3d.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
output_grad_data,
out_h,
out_w,
nc,
ratio_h,
ratio_w,
align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
KeNearestNeighborInterpBw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
n,
in_chw,
output_grad_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_corners,
interp_divmods);
}
} else if ("bilinear" == interp_method) {
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
    bool is_nchw = (data_layout == DataLayout::kNCHW);
    bool optimize_flag = false;
#ifndef __HIPCC__
    optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6)) ||
                    (in_h == 1 && in_w == 1);
#endif
    if (optimize_flag && is_nchw) {
KeBilinearInterpBwShareMemory<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
output_grad_data,
out_h,
out_w,
n,
c,
ratio_h,
ratio_w,
align_type_value,
is_nchw);
    } else if (!optimize_flag && is_nchw) {
const int num_kernels = n * c * out_h * out_w;
const int num_threads = std::min(dev_ctx.GetMaxThreadsPerBlock(), 1024);
KeBilinearInterpNCHWBw<
T><<<backends::gpu::DivUp(num_kernels, num_threads),
num_threads,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
out_h,
out_w,
n,
c,
ratio_h,
ratio_w,
output_grad_data,
align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpBw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
output_grad_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_type_value,
interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpBw<
T><<<config.block_per_grid, thread_per_block, 0, dev_ctx.stream()>>>(
input_grad_data,
in_h,
in_w,
n,
in_chw,
output_grad_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_corners,
data_layout);
}
}
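// Host-side driver for the 3D backward pass; only "trilinear" and "nearest"
// are handled, both launched with a 1D grid-stride configuration.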
template <typename T, typename Context>
static void Interpolate3DCUDABwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* input_grad) {
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
auto* output_grad_data = output_grad.data<T>();
phi::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
input_grad->Resize(dim_grad);
auto* input_grad_data = dev_ctx.template Alloc<T>(input_grad);
phi::funcs::SetConstant<Context, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(output_grad, dev_ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_dhw = in_d * in_h * in_w;
int64_t out_dhw = out_d * out_h * out_w;
int64_t in_cdhw = c * in_dhw;
int64_t out_cdhw = c * out_dhw;
auto pixelNum = n * out_cdhw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_d,
in_h,
in_w,
n,
in_cdhw,
output_grad_data,
out_d,
out_h,
out_w,
n,
out_cdhw,
c,
ratio_d,
ratio_h,
ratio_w,
align_corners,
align_mode,
data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpBw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_grad_data,
in_d,
in_h,
in_w,
n,
in_cdhw,
output_grad_data,
out_d,
out_h,
out_w,
n,
out_cdhw,
c,
ratio_d,
ratio_h,
ratio_w,
align_corners,
data_layout);
}
}
template <typename T, typename Context>
void InterpolateGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& output_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
auto output_grad_dims = output_grad.dims();
if (output_grad_dims.size() == 3) { // 1D interpolation grad
Interpolate1DCUDABwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation grad
Interpolate2DCUDABwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation grad
Interpolate3DCUDABwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
output_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
}
template <typename T, typename Context>
void BilinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void NearestInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void TrilinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void LinearInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
template <typename T, typename Context>
void BicubicInterpGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad) {
InterpolateGradKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
out_grad,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::BilinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(trilinear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(linear_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::LinearInterpGradKernel,
float,
double) {}
PD_REGISTER_KERNEL(bicubic_interp_v2_grad,
GPU,
ALL_LAYOUT,
phi::BicubicInterpGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/interpolate_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
namespace phi {
using paddle::platform::FastDivMod;
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
int* in_img_idx,
int* x_id,
T* lambda1,
T* lambda2,
T src_x,
const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f;
*in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx;
*lambda2 = 1.f - *lambda1;
}
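// For illustration: src_x = 2.3 with in_img_x = 5 yields in_img_idx = 2,
// x_id = 1, lambda1 = 0.3, lambda2 = 0.7. Negative source coordinates are
// clamped to 0, and x_id becomes 0 at the right border so the neighbor read
// stays inside the image.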
template <typename T>
__global__ void KeLinearInterpFw(const T* in,
const size_t in_img_w,
const size_t input_w,
T* out,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
    if (data_layout == DataLayout::kNCHW) {
      const T* in_pos =
          &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id];
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels];
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpNCHWFw(const T* in,
const size_t in_img_h,
const size_t in_img_w,
T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t nc,
const float ratio_h,
const float ratio_w,
const bool align_corners) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
  // nearest sampling: map this thread's output (x, y) to the nearest input
  // (x, y), then copy that element for every (n, c) slice the thread covers
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
  // only threads whose output coordinates are in range write results
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
out[out_index] = in[in_index];
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_h,
const float ratio_w,
const bool align_corners,
funcs::FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
template <typename T>
__global__ void KeBilinearInterpFw(const T* in,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_h,
const float ratio_w,
const T align_type_value,
funcs::FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
PreCalculatorForLinearInterpInputIndex(
&in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_img_h);
// bilinear interpolation
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
out[tid] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda *
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda *
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
}
}
template <typename T>
__global__ void KeBilinearInterpNCHWFw(const T* in,
const size_t in_img_h,
const size_t in_img_w,
T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t nc,
const float ratio_h,
const float ratio_w,
const T align_type_value) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
PreCalculatorForLinearInterpInputIndex(
&in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_img_h);
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;
  // only threads whose output coordinates are in range write results
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
const T* in_pos = &in[in_index];
out[out_index] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}
}
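// Cubic convolution (Keys kernel) with a = -0.75: the four weights always
// sum to 1, and at t == 0 they reduce to (0, 1, 0, 0) so the result is
// exactly x1; at t == 1 the result is exactly x2.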
template <typename T>
__device__ __forceinline__ static T Kecubic_interp(
const T x0, const T x1, const T x2, const T x3, T t) {
T coeffs[4];
T a = -0.75;
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = funcs::CubicConvolution2<T>(x_1 + 1.0, a);
coeffs[1] = funcs::CubicConvolution1<T>(x_1, a);
coeffs[2] = funcs::CubicConvolution1<T>(x_2, a);
coeffs[3] = funcs::CubicConvolution2<T>(x_2 + 1.0, a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
__global__ void KeBicubicInterpFw(const T* in,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
T* out,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_h,
const float ratio_w,
const bool align_corners,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T coefficients[4];
const T* in_pos_0;
const T* in_pos_1;
const T* in_pos_2;
const T* in_pos_3;
int access_x_0;
if (data_layout == DataLayout::kNCHW) {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_0];
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_1];
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_2];
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_3];
coefficients[k] = Kecubic_interp<T>(
in_pos_0[0], in_pos_1[0], in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] = Kecubic_interp<T>(coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
y_t);
} else {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
int access_x_0 =
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
const T* in_pos_0 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_0 * num_channels + channel_id];
const T* in_pos_1 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_1 * num_channels + channel_id];
const T* in_pos_2 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_2 * num_channels + channel_id];
const T* in_pos_3 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp(
in_pos_0[0], in_pos_1[0], in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
static_cast<T>(Kecubic_interp(coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
y_t));
}
}
}
template <typename T>
__global__ void KeTrilinearInterpFw(const T* in,
const size_t in_img_d,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
T* out,
const size_t out_img_d,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
w1lambda * in_pos1[h_id * in_img_w + w_id])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
w1lambda * in_pos2[h_id * in_img_w + w_id]));
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] +
w1lambda * in_pos1[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w * num_channels] +
w1lambda * in_pos1[h_id * in_img_w * num_channels +
w_id * num_channels])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] +
w1lambda * in_pos2[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w * num_channels] +
w1lambda * in_pos2[h_id * in_img_w * num_channels +
w_id * num_channels]));
}
}
}
template <typename T>
__global__ void KeNearestNeighbor3DInterpFw(const T* in,
const size_t in_img_d,
const size_t in_img_h,
const size_t in_img_w,
const size_t input_h,
const size_t input_w,
T* out,
const size_t out_img_d,
const size_t out_img_h,
const size_t out_img_w,
const size_t output_h,
const size_t output_w,
const size_t num_channels,
const float ratio_d,
const float ratio_h,
const float ratio_w,
const bool align_corners,
const DataLayout data_layout) {
int nthreads = output_h * output_w; // ncdhw
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
out[tid] = in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
}
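// Host-side driver for the 1D (linear) forward pass. The output width is
// resolved with SizeTensor taking priority over OutSize, which in turn
// overrides the scale (tensor or attribute), and must be positive. When the
// size is unchanged the input is copied to the output directly.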
template <typename T, typename Context>
static void Interpolate1DCUDAFwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
auto* input_data = input.data<T>();
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_w = -1;
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_w = new_size[0];
} else {
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
}
PADDLE_ENFORCE_GT(
out_w,
0,
errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
output->Resize(dim_out);
auto output_data = dev_ctx.template Alloc<T>(output);
if (in_w == out_w) {
paddle::framework::TensorCopy(input, dev_ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1.0) / (out_w - 1.0)
: static_cast<float>(new_scale_w);
}
int64_t in_cw = c * in_w;
int64_t out_cw = c * out_w;
auto pixelNum = n * out_cw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_w,
in_cw,
output_data,
out_w,
n,
out_cw,
c,
ratio_w,
align_corners,
align_mode,
data_layout);
}
}
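// Host-side driver for the 2D forward pass. "nearest" and "bilinear" use a
// 3D launch over (w, h, n*c) for NCHW and a fast-div/mod 1D kernel for
// NHWC; "bicubic" uses a fixed block size (256 on HIP, 512 on CUDA). On
// Jetson builds (SM 5.3 / 6.2) the bilinear NHWC launch uses 512 threads
// per block.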
template <typename T, typename Context>
static void Interpolate2DCUDAFwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
auto* input_data = input.data<T>();
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_w = -1;
float scale_h = -1;
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_h = new_size[0];
out_w = new_size[1];
} else {
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
}
PADDLE_ENFORCE_GT(
out_h,
0,
errors::InvalidArgument("out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(
out_w,
0,
errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
output->Resize(dim_out);
auto output_data = dev_ctx.template Alloc<T>(output);
if (in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(input, dev_ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_hw = in_h * in_w;
int64_t out_hw = out_h * out_w;
int64_t in_chw = c * in_hw;
int64_t out_chw = c * out_hw;
auto pixelNum = n * out_chw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("nearest" == interp_method) {
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
backends::gpu::GpuLaunchConfig config_3d =
backends::gpu::GetGpuLaunchConfig3D(dev_ctx, nc, out_h, out_w);
KeNearestNeighborInterpNCHWFw<T><<<config_3d.block_per_grid,
config_3d.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_h,
in_w,
output_data,
out_h,
out_w,
nc,
ratio_h,
ratio_w,
align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
KeNearestNeighborInterpFw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_h,
in_w,
n,
in_chw,
output_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_corners,
interp_divmods);
}
} else if ("bilinear" == interp_method) {
dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
if (config.compute_capability == 53 || config.compute_capability == 62) {
thread_num = 512;
}
#endif
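    // With align_mode == 0 and align_corners disabled, a half-pixel (0.5)
    // offset is handed to the bilinear kernels; otherwise no offset is used.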
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
backends::gpu::GpuLaunchConfig config_3d =
backends::gpu::GetGpuLaunchConfig3D(dev_ctx, nc, out_h, out_w);
KeBilinearInterpNCHWFw<T><<<config_3d.block_per_grid,
config_3d.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_h,
in_w,
output_data,
out_h,
out_w,
nc,
ratio_h,
ratio_w,
align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpFw<
T><<<config.block_per_grid, thread_num, 0, dev_ctx.stream()>>>(
input_data,
in_h,
in_w,
n,
in_chw,
output_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_type_value,
interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpFw<
T><<<config.block_per_grid, thread_per_block, 0, dev_ctx.stream()>>>(
input_data,
in_h,
in_w,
n,
in_chw,
output_data,
out_h,
out_w,
n,
out_chw,
c,
ratio_h,
ratio_w,
align_corners,
data_layout);
}
}
template <typename T, typename Context>
static void Interpolate3DCUDAFwd(
const Context& dev_ctx,
const DenseTensor& input,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
auto* input_data = input.data<T>();
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
float scale_w = -1;
float scale_d = -1;
float scale_h = -1;
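  // Output size resolution mirrors the 2-D path: SizeTensor first, then the
  // scale tensor/attribute, then OutSize as an override.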
if (size_tensor && size_tensor->size() > 0) {
// have size tensor
auto new_size = funcs::get_new_shape(size_tensor.get());
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
if (out_size) {
DenseTensor sizes;
paddle::framework::TensorCopySync(
*out_size, paddle::platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
}
PADDLE_ENFORCE_GT(
out_d,
0,
errors::InvalidArgument("out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(
out_h,
0,
errors::InvalidArgument("out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(
out_w,
0,
errors::InvalidArgument("out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w};
} else {
dim_out = {n, out_d, out_h, out_w, c};
}
output->Resize(dim_out);
auto output_data = dev_ctx.template Alloc<T>(output);
if (in_d == out_d && in_h == out_h && in_w == out_w) {
paddle::framework::TensorCopy(input, dev_ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
float new_scale_d = 0.f;
new_scale_d = (scale_d > 0) ? static_cast<float>(1. / scale_d)
: static_cast<float>(in_d) / out_d;
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(new_scale_d);
}
if (out_h > 1) {
float new_scale_h = 0.f;
new_scale_h = (scale_h > 0) ? static_cast<float>(1. / scale_h)
: static_cast<float>(in_h) / out_h;
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(new_scale_h);
}
if (out_w > 1) {
float new_scale_w = 0.f;
new_scale_w = (scale_w > 0) ? static_cast<float>(1. / scale_w)
: static_cast<float>(in_w) / out_w;
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(new_scale_w);
}
int64_t in_dhw = in_d * in_h * in_w;
int64_t out_dhw = out_d * out_h * out_w;
int64_t in_cdhw = c * in_dhw;
int64_t out_cdhw = c * out_dhw;
auto pixelNum = n * out_cdhw;
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, pixelNum);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_d,
in_h,
in_w,
n,
in_cdhw,
output_data,
out_d,
out_h,
out_w,
n,
out_cdhw,
c,
ratio_d,
ratio_h,
ratio_w,
align_corners,
align_mode,
data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpFw<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_data,
in_d,
in_h,
in_w,
n,
in_cdhw,
output_data,
out_d,
out_h,
out_w,
n,
out_cdhw,
c,
ratio_d,
ratio_h,
ratio_w,
align_corners,
data_layout);
}
}
template <typename T, typename Context>
void InterpolateKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
auto input_dims = x.dims();
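  // Dispatch on input rank: 3-D tensors use 1-D interpolation, 4-D use 2-D,
  // and 5-D use 3-D.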
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDAFwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDAFwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDAFwd<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
}
template <typename T, typename Context>
void BilinearInterpKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
template <typename T, typename Context>
void NearestInterpKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
template <typename T, typename Context>
void TrilinearInterpKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
template <typename T, typename Context>
void LinearInterpKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
template <typename T, typename Context>
void BicubicInterpKernel(
const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(dev_ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output);
}
} // namespace phi
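// GPU forward kernel registrations for the *_interp_v2 ops and their
// supported element types.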
PD_REGISTER_KERNEL(bilinear_interp_v2,
GPU,
ALL_LAYOUT,
phi::BilinearInterpKernel,
float,
double,
int) {}
PD_REGISTER_KERNEL(nearest_interp_v2,
GPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(trilinear_interp_v2,
GPU,
ALL_LAYOUT,
phi::TrilinearInterpKernel,
float,
double,
int) {}
PD_REGISTER_KERNEL(linear_interp_v2,
GPU,
ALL_LAYOUT,
phi::LinearInterpKernel,
float,
double,
int) {}
PD_REGISTER_KERNEL(bicubic_interp_v2,
GPU,
ALL_LAYOUT,
phi::BicubicInterpKernel,
float,
double,
int) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BilinearInterpGradKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
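// Forward kernel declarations for the *_interp_v2 ops; all variants share the
// same signature and select the interpolation algorithm via interp_method.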
template <typename T, typename Context>
void BilinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void NearestInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void TrilinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void LinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void BicubicInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
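// Argument mappings translate the fluid *_interp_v2 op inputs and attributes
// into the corresponding phi kernel signatures (forward and grad).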
KernelSignature BilinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature NearestInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nearest_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature TrilinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("trilinear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature LinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("linear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature BicubicInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bicubic_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(bilinear_interp_v2,
phi::BilinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nearest_interp_v2,
phi::NearestInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trilinear_interp_v2,
phi::TrilinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(linear_interp_v2,
phi::LinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bicubic_interp_v2,
phi::BicubicInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bilinear_interp_v2_grad,
phi::BilinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nearest_interp_v2_grad,
phi::NearestInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trilinear_interp_v2_grad,
phi::TrilinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(linear_interp_v2_grad,
phi::LinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bicubic_interp_v2_grad,
phi::BicubicInterpGradOpArgumentMapping);
...@@ -41,7 +41,9 @@ class TrtConvertNearestInterpV2Test(TrtLayerAutoScanTest):
     "data_layout": "NCHW",
     "interp_method": "nearest",
     "align_corners": False,
+    "align_mode": 1,
     "scale": [2., 2.],
+    "out_d": 0,
     "out_h": 0,
     "out_w": 0
 }
...