未验证 提交 66655ad8 编写于 作者: W WuHaobo 提交者: GitHub

test=release/1.8 cherry-pick unfold_op (#24505)

上级 ec0f78a3
......@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnfoldOp should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of UnfoldOp should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnfoldOp should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Y"), true,
platform::errors::NotFound("Output(Y) of UnfoldOp should not be null"));
auto in_dims = ctx->GetInputDim("X");
std::vector<int> kernel_sizes =
ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
......@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("dilations");
// Only [N, C, H, W] input supported now
PADDLE_ENFORCE(
in_dims.size() == 4,
PADDLE_ENFORCE_EQ(
in_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be 4-D tensor of format [N, C, H, W], but get %u",
in_dims.size());
PADDLE_ENFORCE(
in_dims.size() - kernel_sizes.size() == 2U,
in_dims.size()));
PADDLE_ENFORCE_EQ(
in_dims.size() - kernel_sizes.size(), 2U,
platform::errors::InvalidArgument(
"The dims of X should be larger than that of kernel_sizes "
"by a number of 2, due to the batch size and input channel dim. "
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
in_dims.size(), kernel_sizes.size());
in_dims.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
strides.size(), kernel_sizes.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of kernel_sizes. "
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
strides.size(), kernel_sizes.size());
strides.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ(
paddings.size(), 2 * strides.size(),
platform::errors::InvalidArgument(
"The dims of paddings should be 2 times of that of strides. "
"But recieved dims(paddings: %u) != 2*dims(strides: %u).",
paddings.size(), strides.size());
paddings.size(), strides.size()));
PADDLE_ENFORCE_EQ(
strides.size(), dilations.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of dilations. "
"But recieved dims(strides: %u) != dims(dilations: %u).",
strides.size(), dilations.size());
strides.size(), dilations.size()));
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);
......@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"The gradient of Y should not be null");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"The gradient of X should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Y")), true,
platform::errors::NotFound("The gradient of Y should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("The input X should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("The gradient of X should not be null"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......
......@@ -29,12 +29,15 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation,
int padding1, int padding2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
PADDLE_ENFORCE(output_size > 0,
PADDLE_ENFORCE_GT(
output_size, 0UL,
platform::errors::InvalidArgument(
"Due to the settings of padding(%d, %d), filter_size(%d), "
"dilation(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding1, padding2, filter_size, dilation, stride, input_size);
padding1, padding2, filter_size, dilation, stride, input_size));
return output_size;
}
......
......@@ -15257,6 +15257,8 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
helper = LayerHelper("unfold", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'unfold')
assert len(x.shape) == 4, \
"input should be the format of [N, C, H, W]"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册