未验证 提交 1c00732d 编写于 作者: W WuHaobo 提交者: GitHub

Polish PADDLE_ENFORCE of unfold_op (#24423)

上级 71ff32b6
...@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE_EQ(
"Input(X) of UnfoldOp should not be null"); ctx->HasInput("X"), true,
PADDLE_ENFORCE(ctx->HasOutput("Y"), platform::errors::NotFound("Input(X) of UnfoldOp should not be null"));
"Output(Y) of UnfoldOp should not be null"); PADDLE_ENFORCE_EQ(
ctx->HasOutput("Y"), true,
platform::errors::NotFound("Output(Y) of UnfoldOp should not be null"));
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
std::vector<int> kernel_sizes = std::vector<int> kernel_sizes =
ctx->Attrs().Get<std::vector<int>>("kernel_sizes"); ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
...@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("dilations"); ctx->Attrs().Get<std::vector<int>>("dilations");
// Only [N, C, H, W] input supported now // Only [N, C, H, W] input supported now
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
in_dims.size() == 4, in_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be 4-D tensor of format [N, C, H, W], but get %u", "Input should be 4-D tensor of format [N, C, H, W], but get %u",
in_dims.size()); in_dims.size()));
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
in_dims.size() - kernel_sizes.size() == 2U, in_dims.size() - kernel_sizes.size(), 2U,
platform::errors::InvalidArgument(
"The dims of X should be larger than that of kernel_sizes " "The dims of X should be larger than that of kernel_sizes "
"by a number of 2, due to the batch size and input channel dim. " "by a number of 2, due to the batch size and input channel dim. "
"But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2", "But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
in_dims.size(), kernel_sizes.size()); in_dims.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
strides.size(), kernel_sizes.size(), strides.size(), kernel_sizes.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of kernel_sizes. " "The dims of strides should be the same with that of kernel_sizes. "
"But recieved dims(strides: %u) != dims(kernel_sizes: %u).", "But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
strides.size(), kernel_sizes.size()); strides.size(), kernel_sizes.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddings.size(), 2 * strides.size(), paddings.size(), 2 * strides.size(),
platform::errors::InvalidArgument(
"The dims of paddings should be 2 times of that of strides. " "The dims of paddings should be 2 times of that of strides. "
"But recieved dims(paddings: %u) != 2*dims(strides: %u).", "But recieved dims(paddings: %u) != 2*dims(strides: %u).",
paddings.size(), strides.size()); paddings.size(), strides.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
strides.size(), dilations.size(), strides.size(), dilations.size(),
platform::errors::InvalidArgument(
"The dims of strides should be the same with that of dilations. " "The dims of strides should be the same with that of dilations. "
"But recieved dims(strides: %u) != dims(dilations: %u).", "But recieved dims(strides: %u) != dims(dilations: %u).",
strides.size(), dilations.size()); strides.size(), dilations.size()));
std::vector<int> out_dims; std::vector<int> out_dims;
out_dims.push_back(in_dims[0]); out_dims.push_back(in_dims[0]);
...@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel { ...@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), PADDLE_ENFORCE_EQ(
"The gradient of Y should not be null"); ctx->HasInput(framework::GradVarName("Y")), true,
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null"); platform::errors::NotFound("The gradient of Y should not be null"));
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE_EQ(
"The gradient of X should not be null"); ctx->HasInput("X"), true,
platform::errors::NotFound("The input X should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("The gradient of X should not be null"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
......
...@@ -29,12 +29,14 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation, ...@@ -29,12 +29,14 @@ inline int CalcOutputSize(int input_size, int filter_size, int dilation,
int padding1, int padding2, int stride) { int padding1, int padding2, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1; const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1; int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1;
PADDLE_ENFORCE(output_size > 0, PADDLE_ENFORCE_GT(
output_size, 0UL,
platform::errors::InvalidArgument(
"Due to the settings of padding(%d, %d), filter_size(%d), " "Due to the settings of padding(%d, %d), filter_size(%d), "
"dilation(%d) and " "dilation(%d) and "
"stride(%d), the output size is less than 0, please check " "stride(%d), the output size is less than 0, please check "
"again. Input_size:%d", "again. Input_size:%d",
padding1, padding2, filter_size, dilation, stride, input_size); padding1, padding2, filter_size, dilation, stride, input_size));
return output_size; return output_size;
} }
......
...@@ -13990,6 +13990,8 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None): ...@@ -13990,6 +13990,8 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
helper = LayerHelper("unfold", **locals()) helper = LayerHelper("unfold", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'unfold')
assert len(x.shape) == 4, \ assert len(x.shape) == 4, \
"input should be the format of [N, C, H, W]" "input should be the format of [N, C, H, W]"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册