diff --git a/paddle/fluid/operators/unfold_op.cc b/paddle/fluid/operators/unfold_op.cc index 0a36b6ef8408873a532883031e831a5ea21838e0..5c0eb64993b55646203a91e7bbc6d978105e8027 100644 --- a/paddle/fluid/operators/unfold_op.cc +++ b/paddle/fluid/operators/unfold_op.cc @@ -107,6 +107,42 @@ class UnfoldOp : public framework::OperatorWithKernel { "But recieved dims(strides: %u) != dims(dilations: %u).", strides.size(), dilations.size())); + // check kernel_sizes + PADDLE_ENFORCE_GT(kernel_sizes[0], 0, + platform::errors::InvalidArgument( + "The `kernel_sizes` should be greater than zero, " + "but recieved kernel_height: %d kernel_width: %d.", + kernel_sizes[0], kernel_sizes[1])); + PADDLE_ENFORCE_GT(kernel_sizes[1], 0, + platform::errors::InvalidArgument( + "The `kernel_sizes` should be greater than zero, " + "but recieved kernel_height: %d kernel_width: %d.", + kernel_sizes[0], kernel_sizes[1])); + // check strides + PADDLE_ENFORCE_GT(strides[0], 0, + platform::errors::InvalidArgument( + "The `strides` should be greater than zero, " + "but recieved strides_height: %d strides_width: %d.", + strides[0], strides[1])); + PADDLE_ENFORCE_GT(strides[1], 0, + platform::errors::InvalidArgument( + "The `strides` should be greater than zero, " + "but recieved strides_height: %d strides_width: %d.", + strides[0], strides[1])); + // check dilations + PADDLE_ENFORCE_GT( + dilations[0], 0, + platform::errors::InvalidArgument( + "The `dilations` should be greater than zero, " + "but recieved dilations_height: %d dilations_width: %d.", + dilations[0], dilations[1])); + PADDLE_ENFORCE_GT( + dilations[1], 0, + platform::errors::InvalidArgument( + "The `dilations` should be greater than zero, " + "but recieved dilations_height: %d dilations_width: %d.", + dilations[0], dilations[1])); + std::vector out_dims; out_dims.push_back(in_dims[0]);