未验证 提交 d65a7a46 编写于 作者: C chentianyu03 提交者: GitHub

[Phi]Interploatd kernels into phi (#40855)

* add interploate cpu kernel

* fix nullptr bug

* add interpolate gpu kernel

* fix unit test error

* remove raw kernels

* add cuda kernel impl

* add infermeta

* recover accidentally deleted kernels in interpolate op

* fix grad x_grad name error

* remove interpolate_v2_op.h

* rm unused codes

* fix xpu build error

* fix build error

* fix namespace error

* add register header for nup

* fix infermeta error

* modify by review

* add the missing args in test_trt_convert_nearest_interp_v2
上级 597d7efd
......@@ -2167,7 +2167,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
typeid(paddle::optional<const phi::DenseTensor&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<const phi::SelectedRows&>)))) {
typeid(paddle::optional<const phi::SelectedRows&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<
const std::vector<const phi::DenseTensor*>>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
......@@ -2429,6 +2433,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
std::type_index(typeid(std::vector<std::string>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr_it->second));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
......@@ -272,6 +272,14 @@ void BuildDygraphPhiKernelContext(
auto end_idx = start_idx + 1;
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
continue;
} else if (input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<
const std::vector<const phi::DenseTensor*>>))) {
kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
continue;
} else {
PADDLE_THROW(phi::errors::NotFound(
"Can not find input variable '%s' for %s OP, please check whether "
......@@ -545,6 +553,9 @@ void BuildDygraphPhiKernelContext(
std::type_index(typeid(std::vector<std::string>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
......@@ -9,11 +9,15 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -722,64 +726,51 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateV2GradNoNeedBufferVarsInferer,
// not
// compatible with interp_op, so a new one is added in paddle2.0
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(bilinear_interp_v2, BilinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(nearest_interp_v2, NearestInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(trilinear_interp_v2,
TrilinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(bicubic_interp_v2, BicubicInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(linear_interp_v2, LinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
REGISTER_OPERATOR(bilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
BilinearInterpInferShapeFunctor);
REGISTER_OPERATOR(bilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
NearestInterpInferShapeFunctor);
REGISTER_OPERATOR(nearest_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
TrilinearInterpInferShapeFunctor);
REGISTER_OPERATOR(trilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
BicubicInterpInferShapeFunctor);
REGISTER_OPERATOR(bicubic_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<int>,
ops::InterpolateV2Kernel<int64_t>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OPERATOR(linear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
LinearInterpInferShapeFunctor);
REGISTER_OPERATOR(linear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
此差异已折叠。
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
namespace paddle {
namespace operators {
......@@ -401,7 +403,8 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w);
phi::funcs::ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
......@@ -431,14 +434,15 @@ class InterpolateV2NPUKernel : public framework::OpKernel<T> {
out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size);
auto out_size_data = phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
......@@ -538,7 +542,8 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> {
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
phi::funcs::ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
......@@ -567,14 +572,15 @@ class InterpolateV2NPUGradKernel : public framework::OpKernel<T> {
out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size);
auto out_size_data = phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
......
......@@ -14,8 +14,7 @@
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
......@@ -57,7 +56,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w);
phi::funcs::ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
......@@ -78,7 +78,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
......@@ -107,7 +108,8 @@ class InterpolateV2XPUKernel : public framework::OpKernel<T> {
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
auto out_size_data =
phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
......@@ -169,7 +171,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
phi::funcs::ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h,
&in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
......@@ -190,7 +193,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
auto scale_data =
phi::funcs::get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
......@@ -219,7 +223,8 @@ class InterpolateV2GradXPUKernel : public framework::OpKernel<T> {
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
auto out_size_data =
phi::funcs::get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
......
......@@ -179,6 +179,43 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
return config;
}
static inline int GetLastPow2(int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return std::max(1, n - (n >> 1));
}
inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context,
int num_img,
int height,
int width) {
const int kThreadsPerBlock = 256;
int max_threads_per_block = context.GetMaxThreadsPerBlock(); // 1024
int max_threads = std::min(kThreadsPerBlock, max_threads_per_block);
int block_x = std::min(GetLastPow2(width), max_threads);
int block_y = std::min(GetLastPow2(height), max_threads / block_x);
int block_z = std::min(num_img, max_threads / block_x / block_y);
auto max_grid_dim = context.GetCUDAMaxGridDimSize();
int grid_x =
std::min<int>(max_grid_dim[0], backends::gpu::DivUp(width, block_x));
int grid_y =
std::min<int>(max_grid_dim[1], backends::gpu::DivUp(height, block_y));
int grid_z = std::min<int>(max_grid_dim[2],
backends::gpu::DivUp(num_img, block_z * 4));
const int capability = context.GetComputeCapability();
GpuLaunchConfig config;
config.compute_capability = capability;
config.thread_per_block = dim3(block_x, block_y, block_z);
config.block_per_grid = dim3(grid_x, grid_y, grid_z);
return config;
}
} // namespace gpu
} // namespace backends
} // namespace phi
......
......@@ -87,6 +87,23 @@ std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
return result;
}
paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = inputs_.at(start);
if (first) {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
}
return paddle::optional<const std::vector<const MetaTensor*>>(result);
}
return paddle::optional<const std::vector<const MetaTensor*>>(paddle::none);
}
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}
......
......@@ -54,6 +54,8 @@ class InferMetaContext {
const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
......@@ -174,6 +176,26 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
template <typename... Tail>
struct InferMetaFnCallHelper<
paddle::optional<const std::vector<const MetaTensor*>>,
Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
paddle::optional<const std::vector<const MetaTensor*>> arg =
ctx->OptionalInputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
......
......@@ -97,6 +97,22 @@ class KernelContext {
return v;
}
template <typename TensorType>
paddle::optional<const std::vector<const TensorType*>> OptionalInputsBetween(
size_t start, size_t end) {
const auto& first = inputs_.at(start);
if (first) {
std::vector<const TensorType*> v;
for (size_t i = start; i < end; ++i) {
auto* t = static_cast<const TensorType*>(inputs_.at(i));
v.emplace_back(t);
}
return paddle::optional<const std::vector<const TensorType*>>(v);
}
return paddle::optional<const std::vector<const TensorType*>>(paddle::none);
}
template <typename TensorType>
TensorType* MutableOutputAt(size_t idx) {
return static_cast<TensorType*>(outputs_.at(idx));
......
......@@ -81,6 +81,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<
const std::vector<const DenseTensor*>>))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<const SelectedRows&>))) {
args_def->AppendInput(default_key.backend(),
......
......@@ -126,6 +126,30 @@ namespace phi {
} \
}
#define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper< \
paddle::optional<const std::vector<const tensor_type*>>, \
Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(attr_idx == 0, \
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
paddle::optional<const std::vector<const tensor_type*>> arg = \
ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg); \
} \
}
#define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \
template <typename... Tail> \
struct KernelCallHelper<attr_type, Tail...> { \
......@@ -224,6 +248,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
......
......@@ -890,6 +890,506 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
static void Interpolate1DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE_EQ("linear",
interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"linear\" when"
"Input(X) dimension is 3, but got method = %s .",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
if (size_tensor && size_tensor->size() > 0) {
// top prority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
1,
phi::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 1. "
"Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
"size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0],
1,
phi::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .",
scale_tensor_dim[0]));
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_w = -1;
scale_w = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
if (scale_w > 0.) {
// round down
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_w)
: static_cast<int>(dim_x[1] * scale_w));
// protect when input shape is -1
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimention = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0],
1,
phi::errors::InvalidArgument(
"OutSize's 0-th dimension's value must be 1, but got value = %d .",
out_size_dim[0]));
// dims will be seted in kernel
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w_tmp};
} else {
dim_out = {dim_x[0], out_w_tmp, dim_x[2]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
static void Interpolate2DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4, but got method = %s.",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
if (size_tensor && size_tensor->size()) {
// top prority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
2,
phi::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input "
"tensor, but got size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_h_tmp, out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 2 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor_dim[0]));
out_h_tmp = -1;
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_h = -1;
float scale_w = -1;
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
phi::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
if (scale_h > 0. && scale_w > 0.) {
// round down
out_h_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_h)
: static_cast<int>(dim_x[1] * scale_h));
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_w)
: static_cast<int>(dim_x[2] * scale_w));
// protect when input shape is -1
out_h_tmp = out_h_tmp > 0 ? out_h_tmp : -1;
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_h_tmp = out_h;
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimension = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0],
2,
phi::errors::InvalidArgument(
"OutSize's dim[0] must be 2, but got dimention = %d .",
out_size_dim[0]));
// dims will be seted in kernel
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h_tmp, out_w_tmp};
} else {
dim_out = {dim_x[0], out_h_tmp, out_w_tmp, dim_x[3]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
static void Interpolate3DInferShapeCheck(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims();
PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method,
phi::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" or "
"\"nearest\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_layout_str);
for (int i = 0; i < dim_x.size(); ++i) {
PADDLE_ENFORCE_NE(
dim_x[i],
0,
phi::errors::InvalidArgument("The shape of input(x) should be larged "
"than 0, bug received shape[%d] is %d ",
i,
dim_x[i]));
}
if (size_tensor && size_tensor->size() > 0) {
// top prority size
auto inputs_name = size_tensor.get();
PADDLE_ENFORCE_EQ(
inputs_name.size(),
3,
phi::errors::InvalidArgument(
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input "
"tensor, but got size = %d .",
inputs_name.size()));
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
return;
}
int out_d_tmp, out_h_tmp, out_w_tmp;
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor->dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got size = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 3 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 3 or 1, but got shape = %d .",
scale_tensor_dim[0]));
out_d_tmp = -1;
out_h_tmp = -1;
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0,
true,
phi::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0,
true,
phi::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
PADDLE_ENFORCE_EQ(
scale_d > 0,
true,
phi::errors::InvalidArgument(
"The scale_d in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_d));
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
// round down
out_d_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_d)
: static_cast<int>(dim_x[1] * scale_d));
out_h_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_h)
: static_cast<int>(dim_x[2] * scale_h));
out_w_tmp = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[4] * scale_w)
: static_cast<int>(dim_x[3] * scale_w));
// protect when input shape is -1
out_d_tmp = out_d_tmp > 0 ? out_d_tmp : -1;
out_h_tmp = out_h_tmp > 0 ? out_h_tmp : -1;
out_w_tmp = out_w_tmp > 0 ? out_w_tmp : -1;
}
} else {
out_d_tmp = out_d;
out_h_tmp = out_h;
out_w_tmp = out_w;
}
}
if (out_size && config.is_runtime) {
auto out_size_dim = out_size->dims();
PADDLE_ENFORCE_EQ(
out_size_dim.size(),
1,
phi::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got size is %d.",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(out_size_dim[0],
3,
phi::errors::InvalidArgument(
"OutSize's dim[0] must be 3, but got size is %d.",
out_size_dim[0]));
// dims will be seted in kernel
output->set_dtype(x.dtype());
output->share_lod(x);
return;
}
phi::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d_tmp, out_h_tmp, out_w_tmp};
} else {
dim_out = {dim_x[0], out_d_tmp, out_h_tmp, out_w_tmp, dim_x[4]};
}
output->set_dims(dim_out);
output->set_dtype(x.dtype());
}
void InterpolateInferMeta(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout_str,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config) {
auto dim_x = x.dims(); // NCHW format
PADDLE_ENFORCE(
dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
phi::errors::Unimplemented(
"Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
dim_x.size()));
if (dim_x.size() == 3) {
// shape check for 1D interpolate for input tensor shape NCHW
Interpolate1DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
} else if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
} else { // dim_x.size() == 5
// shape check for 3D interpolate for input tensor shape NCDHW
Interpolate3DInferShapeCheck(x,
out_size,
size_tensor,
scale_tensor,
data_layout_str,
out_d,
out_h,
out_w,
scale,
interp_method,
align_corners,
align_mode,
output,
config);
}
}
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
......
......@@ -199,6 +199,22 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
MetaTensor* pre_out,
MetaTensor* w_out);
void InterpolateInferMeta(
const MetaTensor& x,
paddle::optional<const MetaTensor&> out_size,
paddle::optional<const std::vector<const MetaTensor*>> size_tensor,
paddle::optional<const MetaTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
MetaTensor* output,
MetaConfig config = MetaConfig());
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
......
此差异已折叠。
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/phi/core/hostdevice.h"
namespace phi {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/fast_divmod.h"
#endif
namespace phi {
namespace funcs {
template <typename T>
HOSTDEVICE inline T CubicConvolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename T>
HOSTDEVICE inline T CubicConvolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75;
T x1 = t;
coeffs[0] = CubicConvolution2<T>(x1 + 1.0, A);
coeffs[1] = CubicConvolution1<T>(x1, A);
// opposite coefficients
T x2 = 1.0 - t;
coeffs[2] = CubicConvolution1<T>(x2, A);
coeffs[3] = CubicConvolution2<T>(x2 + 1.0, A);
}
inline void ExtractNCDWH(const DDim& dims,
const DataLayout& data_layout,
int* N,
int* C,
int* D,
int* H,
int* W) {
*N = dims[0];
if (dims.size() == 3) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2];
*D = 1;
*H = 1;
*W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
} else if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[4];
*D = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*H = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
*W = data_layout == DataLayout::kNCHW ? dims[4] : dims[3];
}
}
inline std::vector<int> get_new_shape(
const std::vector<const DenseTensor*>& list_new_shape_tensor) {
// get tensor from
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(),
phi::make_ddim({1}),
errors::InvalidArgument("The shape of dimension tensor should be [1],"
"but received d%.",
tensor->dims()));
if (paddle::platform::is_gpu_place(tensor->place())) {
DenseTensor temp;
paddle::framework::TensorCopySync(
*tensor, paddle::platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor(
const DenseTensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
DenseTensor cpu_starts_tensor;
if (paddle::platform::is_gpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#ifdef PADDLE_WITH_ASCEND_CL
if (paddle::platform::is_npu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
#ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(
*new_data_tensor, paddle::platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
#if defined(__NVCC__) || defined(__HIPCC__)
using paddle::platform::FastDivMod;
struct FastDivModForInterpolate {
public:
FastDivMod channels_div;
FastDivMod output_w_div;
FastDivMod output_wc_div;
explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
const int output_w,
const int outout_wc)
: channels_div(FastDivMod(channels)),
output_w_div(FastDivMod(output_w)),
output_wc_div(FastDivMod(outout_wc)) {}
};
#endif
} // namespace funcs
} // namespace phi
此差异已折叠。
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BilinearInterpGradKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const DenseTensor& out_grad,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BilinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void NearestInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void TrilinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void LinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
template <typename T, typename Context>
void BicubicInterpKernel(
const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> out_size,
paddle::optional<const std::vector<const DenseTensor*>> size_tensor,
paddle::optional<const DenseTensor&> scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature BilinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature NearestInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nearest_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature TrilinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("trilinear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature LinearInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("linear_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature BicubicInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bicubic_interp_v2",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"Out"});
}
KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(bilinear_interp_v2,
phi::BilinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nearest_interp_v2,
phi::NearestInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trilinear_interp_v2,
phi::TrilinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(linear_interp_v2,
phi::LinearInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bicubic_interp_v2,
phi::BicubicInterpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bilinear_interp_v2_grad,
phi::BilinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(nearest_interp_v2_grad,
phi::NearestInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trilinear_interp_v2_grad,
phi::TrilinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(linear_interp_v2_grad,
phi::LinearInterpGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(bicubic_interp_v2_grad,
phi::BicubicInterpGradOpArgumentMapping);
......@@ -41,7 +41,9 @@ class TrtConvertNearestInterpV2Test(TrtLayerAutoScanTest):
"data_layout": "NCHW",
"interp_method": "nearest",
"align_corners": False,
"align_mode": 1,
"scale": [2., 2.],
"out_d": 0,
"out_h": 0,
"out_w": 0
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册