提交 439d95e1 编写于 作者: Z Zhang Ting 提交者: Aurelius84

modified interpolate op to support tensor attribute, test=develop, test=document_preview (#19287)

modified interpolate_op to support tensor attribute

1. the parameter out_shape of image_resize、resize_nearest/bilinear/trilinear can be a list or a 1-D tensor variable. If a list, each element can be an integer or a tensor variable with shape: [1].

2. the parameter scale of above Ops can be a 1-D tensor variable.
modified document of image_resize, resize_nearest, resize_bilinear, resize_trilinear and add some code example.
上级 b3888941
......@@ -188,11 +188,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '8cfc4f69dbbedb687b6c20732aa8f09e'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '0e8567334d72a214c2e3ce0ce19e4d37'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '832b2412652d84a6631b1012c6e2d18b'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '0a7b98e57eb74bab6e3c2a95e41298a7'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '6baf2ddf375d3059e5aa74d7fde76517'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '699bf1de6af91235367e9c7a9a6e252c'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
......
......@@ -29,20 +29,41 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 2,
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input tensor.");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
return;
}
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
"Scale's dimension size must be 1.");
out_h = -1;
out_w = -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
......@@ -66,24 +87,46 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5");
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 3,
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input tensor.");
int out_d = ctx->Attrs().Get<int>("out_d");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
return;
}
int out_d, out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_d = static_cast<int>(dim_x[2] * scale);
out_h = static_cast<int>(dim_x[3] * scale);
out_w = static_cast<int>(dim_x[4] * scale);
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
"Scale's dimension size must be 1");
out_d = -1;
out_h = -1;
out_w = -1;
} else {
out_d = ctx->Attrs().Get<int>("out_d");
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0.");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_d = static_cast<int>(dim_x[2] * scale);
out_h = static_cast<int>(dim_x[3] * scale);
out_w = static_cast<int>(dim_x[4] * scale);
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_d = ctx->Attrs().Get<int>("out_d");
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
......@@ -129,6 +172,16 @@ class InterpolateOp : public framework::OperatorWithKernel {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -142,7 +195,19 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
"This is a 1-D tensor with two numbers to specify output size. "
"It should be [output_height, output_width] when input is a 4-D "
"tensor and should be [output_depth, output_height, output_width] "
"when input is a 5-D tensor.")
"when input is a 5-D tensor. It has a higher priority than "
"the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
.AsDispensable();
AddInput("SizeTensor",
"(vector<Tensor<int32>>, optional). If provided, interpolate will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(OutSize) and "
"attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
.AsDuplicable()
.AsDispensable();
AddInput("Scale",
"This is a 1-D tensor with one number to specify output scale. "
"It has the higher priority compare with attr(scale).")
.AsDispensable();
AddOutput("Out",
"The output tensor of interpolate operator, "
......@@ -304,6 +369,16 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
......@@ -315,9 +390,15 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOp().Type() + "_grad");
op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("SizeTensor") > 0) {
op->SetInput("SizeTensor", Input("SizeTensor"));
}
if (ForwardOp().Inputs().count("OutSize") > 0) {
op->SetInput("OutSize", Input("OutSize"));
}
if (ForwardOp().Inputs().count("Scale") > 0) {
op->SetInput("Scale", Input("Scale"));
}
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
......
......@@ -365,20 +365,41 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
}
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
auto output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
......@@ -439,22 +460,47 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
}
PADDLE_ENFORCE_GT(
out_d, 0,
"out_d in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
auto output_data =
output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());
......@@ -513,7 +559,14 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
......@@ -522,11 +575,18 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
auto* output_grad_data = output_grad.data<T>();
auto* input_grad_data =
......@@ -591,7 +651,14 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
......@@ -601,12 +668,20 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
auto* output_grad_data = output_grad.data<T>();
auto* input_grad_data =
......
......@@ -23,6 +23,40 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
inline std::vector<int> get_new_shape(
const std::vector<const Tensor*>& list_new_shape_tensor) {
// get tensor from
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
......@@ -403,19 +437,39 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
}
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
......@@ -459,21 +513,45 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
}
PADDLE_ENFORCE_GT(
out_d, 0,
"out_d in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
......@@ -519,18 +597,31 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
......@@ -580,20 +671,34 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_d = static_cast<int>(in_d * scale);
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
......
此差异已折叠。
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
def bilinear_interp_np(input,
......@@ -359,5 +360,154 @@ class TestBilinearInterpZero(TestBilinearInterpOp):
self.align_mode = 0
class TestBilinearInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
self.attrs['scale'] = self.scale
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 4, 4]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = True
# out_size is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case1(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = True
# scale is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case2(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case3(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.scale_by_1Dtensor = True
class TestBilinearInterpOpAPI(OpTest):
def test_case(self):
x = fluid.layers.data(name="x", shape=[3, 6, 6], dtype="float32")
dim = fluid.layers.data(
name="dim", shape=[1], dtype="int32", append_batch_size=False)
shape_tensor = fluid.layers.data(
name="shape_tensor",
shape=[2],
dtype="int32",
append_batch_size=False)
actual_size = fluid.layers.data(
name="actual_size",
shape=[2],
dtype="int32",
append_batch_size=False)
scale_tensor = fluid.layers.data(
name="scale_tensor",
shape=[1],
dtype="float32",
append_batch_size=False)
out1 = fluid.layers.resize_bilinear(x, out_shape=[12, 12])
out2 = fluid.layers.resize_bilinear(x, out_shape=[12, dim])
out3 = fluid.layers.resize_bilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_bilinear(
x, out_shape=[4, 4], actual_shape=actual_size)
out5 = fluid.layers.resize_bilinear(x, scale=scale_tensor)
x_data = np.random.random((1, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
place = core.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = bilinear_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
def nearest_neighbor_interp_np(X,
......@@ -299,5 +300,155 @@ class TestNearestNeighborInterpScale3(TestNearestInterpOp):
self.align_corners = True
class TestNearestInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
self.attrs['scale'] = self.scale
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 4, 4]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = True
# out_size is a tensor list
class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = True
# out_size is a 1-D tensor
class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.scale_by_1Dtensor = True
class TestNearestAPI(OpTest):
def test_case(self):
x = fluid.layers.data(name="x", shape=[3, 6, 6], dtype="float32")
dim = fluid.layers.data(
name="dim", shape=[1], dtype="int32", append_batch_size=False)
shape_tensor = fluid.layers.data(
name="shape_tensor",
shape=[2],
dtype="int32",
append_batch_size=False)
actual_size = fluid.layers.data(
name="actual_size",
shape=[2],
dtype="int32",
append_batch_size=False)
scale_tensor = fluid.layers.data(
name="scale_tensor",
shape=[1],
dtype="float32",
append_batch_size=False)
out1 = fluid.layers.resize_nearest(x, out_shape=[12, 12])
out2 = fluid.layers.resize_nearest(x, out_shape=[12, dim])
out3 = fluid.layers.resize_nearest(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_nearest(
x, out_shape=[4, 4], actual_shape=actual_size)
out5 = fluid.layers.resize_nearest(x, scale=scale_tensor)
x_data = np.random.random((1, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
place = core.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = nearest_neighbor_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
def trilinear_interp_np(input,
......@@ -424,5 +425,166 @@ class TestTrilinearInterpZero(TestTrilinearInterpOp):
self.align_mode = 0
class TestTrilinearInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "trilinear_interp"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale > 0:
out_d = int(self.input_shape[2] * self.scale)
out_h = int(self.input_shape[3] * self.scale)
out_w = int(self.input_shape[4] * self.scale)
self.attrs['scale'] = self.scale
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_d'] = self.out_d
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 4, 4, 4]
self.out_d = 2
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [2, 3, 3]
self.align_corners = True
self.align_mode = 1
# out_size is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case1(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 9, 6, 8]
self.out_d = 32
self.out_h = 16
self.out_w = 8
self.scale = 0.3
self.out_size = [12, 4, 4]
self.align_corners = True
self.align_mode = 1
# scale is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case2(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 8, 8, 4]
self.out_d = 16
self.out_h = 12
self.out_w = 4
self.scale = 0.
self.out_size = [16, 4, 10]
self.align_corners = True
self.align_mode = 1
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case3(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 8, 8, 4]
self.out_d = 16
self.out_h = 16
self.out_w = 8
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.align_mode = 1
self.scale_by_1Dtensor = True
class TestTrilinearInterpAPI(OpTest):
def test_case(self):
x = fluid.layers.data(name="x", shape=[3, 6, 9, 4], dtype="float32")
dim = fluid.layers.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.layers.data(
name="shape_tensor",
shape=[3],
dtype="int32",
append_batch_size=False)
actual_size = fluid.layers.data(
name="actual_size",
shape=[3],
dtype="int32",
append_batch_size=False)
scale_tensor = fluid.layers.data(
name="scale_tensor",
shape=[1],
dtype="float32",
append_batch_size=False)
out1 = fluid.layers.resize_trilinear(x, out_shape=[12, 18, 8])
out2 = fluid.layers.resize_trilinear(x, out_shape=[12, dim, 8])
out3 = fluid.layers.resize_trilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_trilinear(
x, out_shape=[4, 4, 8], actual_shape=actual_size)
out5 = fluid.layers.resize_trilinear(x, scale=scale_tensor)
x_data = np.random.random((1, 3, 6, 9, 4)).astype("float32")
dim_data = np.array([18]).astype("int32")
shape_data = np.array([12, 18, 8]).astype("int32")
actual_size_data = np.array([12, 18, 8]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
place = core.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = trilinear_interp_np(
x_data, out_d=12, out_h=18, out_w=8, align_mode=1)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册