Unverified commit 6c332ad6, authored by liuwei1031, committed by GitHub

improve error messages for conv, conv_transpose, cos_sim, group_norm (#23675)

* improve error messages for conv, conv_transpose, cos_sim, group_norm

Parent 05476e9f
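The commit applies two patterns throughout. Null checks on operator inputs and outputs are collapsed into the OP_INOUT_CHECK helper, and every shape or value check wraps its message in a typed error status (platform::errors::InvalidArgument) instead of passing a bare format string to the enforce macro. A minimal before/after sketch of the pattern, condensed from the conv changes below:

    // Before: a bare message string, with no machine-readable error type.
    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
                      "Input(Input) of ConvOp should not be null.");

    // After: existence checks go through one helper that formats a uniform
    // message from (variable kind, variable name, op name).
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");

    // After: shape checks carry a typed status object, so the failure is
    // classified (InvalidArgument) as well as described.
    PADDLE_ENFORCE_EQ(
        in_dims.size(), filter_dims.size(),
        platform::errors::InvalidArgument(
            "The input's dimension size and filter's dimension size of "
            "Op(conv) should be equal. ..."));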
paddle/fluid/operators/conv_op.cc
@@ -31,12 +31,9 @@ namespace paddle {
 namespace operators {
 
 void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                    "Input(Input) of ConvOp should not be null.");
-  PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
-                    "Input(Filter) of ConvOp should not be null.");
-  PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
-                    "Output(Output) of ConvOp should not be null.");
+  OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
+  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
+  OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
 
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
@@ -56,49 +53,53 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(
       in_dims.size() == 4 || in_dims.size() == 5, true,
-      "ShapeError: the input of Op(conv) should be 4-D or 5-D Tensor. But "
-      "received: %u-D Tensor, the shape of input is [%s].",
-      in_dims.size(), in_dims);
+      platform::errors::InvalidArgument(
+          "The input of Op(conv) should be 4-D or 5-D Tensor. But "
+          "received: %u-D Tensor, the shape of input is [%s].",
+          in_dims.size(), in_dims));
 
   PADDLE_ENFORCE_EQ(
       in_dims.size(), filter_dims.size(),
-      "ShapeError: the input's dimension size and filter's dimension size of "
-      "Op(conv) should be equal. But received: the shape of input is [%s], "
-      "the dimension size of input is [%d], the shape of filter is [%s], "
-      "the dimension size of filter is [%d].",
-      in_dims, in_dims.size(), filter_dims, filter_dims.size());
+      platform::errors::InvalidArgument(
+          "The input's dimension size and filter's dimension size of "
+          "Op(conv) should be equal. But received: the shape of input is "
+          "[%s], the dimension size of input is [%d], the shape of filter "
+          "is [%s], the dimension size of filter is [%d].",
+          in_dims, in_dims.size(), filter_dims, filter_dims.size()));
 
   int in_sub_stride_size = in_dims.size() - strides.size();
-  PADDLE_ENFORCE_EQ(in_dims.size() - strides.size() == 2U, true,
-                    "ShapeError: the dimension size of input minus the size of "
-                    "Attr(stride) must be euqal to 2 for Op(conv)."
-                    "But received: the dimension size of input minus the size "
-                    "of Attr(stride) is [%d], the "
-                    "input's dimension size is [%d], the shape of input "
-                    "is [%s], the Attr(stride)'s size is [%d].",
-                    in_sub_stride_size, in_dims.size(), in_dims,
-                    strides.size());
+  PADDLE_ENFORCE_EQ(
+      in_dims.size(), strides.size() + 2U,
+      platform::errors::InvalidArgument(
+          "The dimension size of input minus the size of "
+          "Attr(stride) must be equal to 2 for Op(conv). "
+          "But received: the dimension size of input minus the size "
+          "of Attr(stride) is [%d], the "
+          "input's dimension size is [%d], the shape of input "
+          "is [%s], the Attr(stride)'s size is [%d].",
+          in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
 
   const auto input_channels =
       channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
   PADDLE_ENFORCE_EQ(
       input_channels, filter_dims[1] * groups,
-      "ShapeError: The number of input channels should be equal to filter "
-      "channels * groups for Op(conv). But received: the input's channels is "
-      "[%d], the shape "
-      "of input is [%s], the filter's channel is [%d], the shape of filter is "
-      "[%s], the groups is [%d], the data_format is %s. The error may come "
-      "from wrong data_format setting.",
-      input_channels, in_dims, filter_dims[1], filter_dims, groups,
-      data_format);
+      platform::errors::InvalidArgument(
+          "The number of input channels should be equal to filter channels * "
+          "groups for Op(conv). But received: the input's channels is [%d], "
+          "the shape of input is [%s], the filter's channel is [%d], the "
+          "shape of filter is [%s], the groups is [%d], the data_format is "
+          "%s. The error may come from wrong data_format setting.",
+          input_channels, in_dims, filter_dims[1], filter_dims, groups,
+          data_format));
   PADDLE_ENFORCE_EQ(
       filter_dims[0] % groups, 0,
-      "ShapeError: The number of output channels of Op(conv) should be divided "
-      "by groups. "
-      "But received: the output channels is [%d], the shape of filter is [%s] "
-      "(the first dimension of filter is output channel), the groups is [%d].",
-      filter_dims[0], filter_dims, groups);
+      platform::errors::InvalidArgument(
+          "The number of output channels of Op(conv) should be divided "
+          "by groups. But received: the output channels is [%d], the shape "
+          "of filter is [%s] (the first dimension of filter is output "
+          "channel), the groups is [%d].",
+          filter_dims[0], filter_dims, groups));
 
   framework::DDim in_data_dims;
   framework::DDim filter_data_dims;
@@ -169,11 +170,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
       input_data_type != framework::proto::VarType::UINT8) {
     auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
     PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
-                      "input and filter data type should be consistent");
+                      platform::errors::InvalidArgument(
+                          "input and filter data type should be consistent"));
   }
   if (input_data_type == framework::proto::VarType::FP16) {
     PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
-                      "float16 can only be used when CUDNN is used");
+                      platform::errors::InvalidArgument(
+                          "float16 can only be used when CUDNN is used"));
   }
 
   auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
...
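Concretely, the channel checks above require input_channels == filter_dims[1] * groups and filter_dims[0] % groups == 0. For example (hypothetical shapes), an NCHW input of [8, 64, 32, 32] with groups = 4 needs a filter of shape [out_channels, 16, kh, kw] with out_channels divisible by 4; feeding an NHWC tensor while data_format still says NCHW makes the channel count get read from the wrong axis, which is why the new message explicitly points at data_format.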
paddle/fluid/operators/conv_transpose_op.cc
@@ -29,12 +29,9 @@ namespace operators {
 using DataLayout = framework::DataLayout;
 
 void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                    "Input(Input) of ConvTransposeOp should not be null.");
-  PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
-                    "Input(Filter) of ConvTransposeOp should not be null.");
-  PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
-                    "Output(Output) of ConvTransposeOp should not be null.");
+  OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ConvTranspose");
+  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "ConvTranspose");
+  OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ConvTranspose");
 
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
@@ -53,42 +50,47 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
       : framework::StringToDataLayout(data_layout_str);
 
-  PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
-                    "ShapeError: input of Op(conv_transpose) should be 4-D or "
-                    "5-D Tensor. But received: %u-D Tensor, "
-                    "the shape of input is [%s]",
-                    in_dims.size(), in_dims);
+  PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
+                    platform::errors::InvalidArgument(
+                        "Input of Op(conv_transpose) should be 4-D or "
+                        "5-D Tensor. But received: %u-D Tensor, "
+                        "the shape of input is [%s]",
+                        in_dims.size(), in_dims));
   PADDLE_ENFORCE_EQ(
       in_dims.size(), filter_dims.size(),
-      "ShapeError: the input's dimension size and filter's dimension size of "
-      "Op (conv_transpose) should be equal. But received: the shape of input "
-      "is [%s], the dimension size of input is [%d], the shape of filter is "
-      "[%s], the dimension size of filter is [%d]. ",
-      in_dims, in_dims.size(), filter_dims, filter_dims.size());
+      platform::errors::InvalidArgument(
+          "The input's dimension size and filter's dimension size of "
+          "Op (conv_transpose) should be equal. But received: the shape of "
+          "input is [%s], the dimension size of input is [%d], the shape "
+          "of filter is [%s], the dimension size of filter is [%d]. ",
+          in_dims, in_dims.size(), filter_dims, filter_dims.size()));
   int in_sub_stride_size = in_dims.size() - strides.size();
   PADDLE_ENFORCE_EQ(
       in_dims.size() - strides.size(), 2U,
-      "ShapeError: the input's dimension size minus Attr(stride)'s size must "
-      "be euqal to 2 for Op(conv_transpose). But received: [%d], the "
-      "input's dimension size is [%d], the shape of input "
-      "is [%s], the Attr(stride)'s size is [%d].",
-      in_sub_stride_size, in_dims.size(), in_dims, strides.size());
+      platform::errors::InvalidArgument(
+          "The input's dimension size minus Attr(stride)'s size must "
+          "be equal to 2 for Op(conv_transpose). But received: [%d], the "
+          "input's dimension size is [%d], the shape of input "
+          "is [%s], the Attr(stride)'s size is [%d].",
+          in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
 
   if (output_size.size())
     PADDLE_ENFORCE_EQ(
         output_size.size(), strides.size(),
-        "The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
-        "should be the same.");
+        platform::errors::InvalidArgument(
+            "The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
+            "should be the same."));
 
   const int64_t C =
       (data_layout != DataLayout::kNHWC ? in_dims[1]
                                         : in_dims[in_dims.size() - 1]);
   PADDLE_ENFORCE_EQ(
       C, filter_dims[0],
-      "ShapeError: The number of input channels should be equal to filter "
-      "channels for Op(conv_transpose). But received: the input's channels is "
-      "[%d], the shape of input is [%s], the filter's channels is [%d], the "
-      "shape of filter is [%s]. The data_format is %s."
-      "The error may come from wrong data_format setting.",
-      C, in_dims, filter_dims[0], filter_dims, data_layout_str);
+      platform::errors::InvalidArgument(
+          "The number of input channels should be equal to filter channels "
+          "for Op(conv_transpose). But received: the input's channels is "
+          "[%d], the shape of input is [%s], the filter's channels is [%d], "
+          "the shape of filter is [%s]. The data_format is %s. "
+          "The error may come from wrong data_format setting.",
+          C, in_dims, filter_dims[0], filter_dims, data_layout_str));
 
   framework::DDim in_data_dims;
   if (data_layout != DataLayout::kNHWC) {
...
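Note the asymmetry with conv: here the check is C == filter_dims[0], i.e. for conv_transpose the filter's first dimension holds the input channel count, whereas for conv the first dimension is the output channel count. For example (hypothetical shapes), an NCHW input of [8, 64, 16, 16] requires a filter whose shape begins with 64.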
paddle/fluid/operators/cos_sim_op.cc
@@ -26,16 +26,11 @@ class CosSimOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of CosSimOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"),
-                   "Input(Y) of CosSimOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of CosSimOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
-                   "Output(XNorm) of CosSimOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
-                   "Output(YNorm) of CosSimOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSim");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSim");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CosSim");
+    OP_INOUT_CHECK(ctx->HasOutput("XNorm"), "Output", "XNorm", "CosSim");
+    OP_INOUT_CHECK(ctx->HasOutput("YNorm"), "Output", "YNorm", "CosSim");
 
     // shape check
     auto x_dims = ctx->GetInputDim("X");
@@ -48,19 +43,28 @@ class CosSimOp : public framework::OperatorWithKernel {
     }
 
     if (check) {
-      PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
-                        "Ranks of Input(X) and Input(Y) must be equal.");
-      PADDLE_ENFORCE_GE(x_dims.size(), 2,
-                        "Rank of Input(X) must not be less than 2.");
+      PADDLE_ENFORCE_EQ(
+          x_dims.size(), y_dims.size(),
+          platform::errors::InvalidArgument(
+              "Ranks of Input(X) [%s] and Input(Y) [%s] must be equal.",
+              x_dims, y_dims));
+      PADDLE_ENFORCE_GE(
+          x_dims.size(), 2,
+          platform::errors::InvalidArgument(
+              "Rank of Input(X) %d must not be less than 2.", x_dims.size()));
       PADDLE_ENFORCE_EQ(
           framework::slice_ddim(x_dims, 1, x_dims.size()),
           framework::slice_ddim(y_dims, 1, y_dims.size()),
-          "All dimensions except the 1st of Input(X) and Input(Y) "
-          "must be equal.");
+          platform::errors::InvalidArgument(
+              "All dimensions except the 1st of Input(X) [%s] and Input(Y) "
+              "[%s] must be equal.",
+              x_dims, y_dims));
       PADDLE_ENFORCE(
           x_dims[0] == y_dims[0] || y_dims[0] == 1,
-          "The 1st dimension of Input(Y) must be equal to Input(X) or"
-          " just 1 (which will be broadcasted to match Input(X)).");
+          platform::errors::InvalidArgument(
+              "The 1st dimension of Input(Y) %d must be equal to Input(X) %d "
+              "or just 1 (which will be broadcasted to match Input(X)).",
+              y_dims[0], x_dims[0]));
     }
 
     // resize tensor
@@ -116,13 +120,13 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     // notnull check
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) must not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSimGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSimGrad");
+    OP_INOUT_CHECK(ctx->HasInput("XNorm"), "Input", "XNorm", "CosSimGrad");
+    OP_INOUT_CHECK(ctx->HasInput("YNorm"), "Input", "YNorm", "CosSimGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CosSimGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "CosSimGrad");
 
     // shape check
     auto x_dims = ctx->GetInputDim("X");
@@ -133,26 +137,48 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
     auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                      "Ranks of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE_GE(x_dims.size(), 2,
-                      "Rank of Input(X) must not be less than 2.");
-    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
-                      framework::slice_ddim(y_dims, 1, y_dims.size()),
-                      "All dimensions except the 1st of Input(X) and Input(Y) "
-                      "must be equal.");
-    PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
-                   "The 1st dimension of Input(Y) must be equal to Input(X) or"
-                   " just 1 (which will be broadcasted to match Input(X)).");
+                      platform::errors::InvalidArgument(
+                          "Ranks of Input(X) %d and Input(Y) %d must be equal.",
+                          x_dims.size(), y_dims.size()));
+    PADDLE_ENFORCE_GE(
+        x_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Rank of Input(X) %d must not be less than 2.", x_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, 1, x_dims.size()),
+        framework::slice_ddim(y_dims, 1, y_dims.size()),
+        platform::errors::InvalidArgument(
+            "All dimensions except the 1st of Input(X) [%s] and Input(Y) [%s] "
+            "must be equal.",
+            x_dims, y_dims));
+    PADDLE_ENFORCE_EQ(
+        true, x_dims[0] == y_dims[0] || y_dims[0] == 1,
+        platform::errors::InvalidArgument(
+            "The 1st dimension of Input(Y) %d must be equal to Input(X) %d "
+            "or just 1 (which will be broadcasted to match Input(X)).",
+            y_dims[0], x_dims[0]));
 
     auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1});
     auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1});
-    PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
-                      "Shape of Input(XNorm) must be [X.Dim(0), 1].");
-    PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
-                      "Shape of Input(YNorm) must be [Y.Dim(0), 1].");
-    PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
-                      "Shape of Input(Out) must be [X.Dim(0), 1].");
-    PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
-                      "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
+    PADDLE_ENFORCE_EQ(
+        xnorm_dims, target_xnorm_dims,
+        platform::errors::InvalidArgument(
+            "Shape of Input(XNorm) [%s] must be (X.Dim(0), 1) - [%s]",
+            xnorm_dims, target_xnorm_dims));
+    PADDLE_ENFORCE_EQ(
+        ynorm_dims, target_ynorm_dims,
+        platform::errors::InvalidArgument(
+            "Shape of Input(YNorm) [%s] must be (Y.Dim(0), 1) - [%s]",
+            ynorm_dims, target_ynorm_dims));
+    PADDLE_ENFORCE_EQ(
+        out_dims, target_xnorm_dims,
+        platform::errors::InvalidArgument(
+            "Shape of Input(Out) [%s] must be (X.Dim(0), 1) - [%s]", out_dims,
+            target_xnorm_dims));
+    PADDLE_ENFORCE_EQ(
+        out_grad_dims, target_xnorm_dims,
+        platform::errors::InvalidArgument(
+            "Shape of Input(Out@Grad) [%s] must be (X.Dim(0), 1) - [%s]",
+            out_grad_dims, target_xnorm_dims));
 
     // resize tensor
    auto x_grad_name = framework::GradVarName("X");
...
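The CosSim checks encode a row-wise contract: X and Y must agree on every dimension except the first, and Y's first dimension must either match X's or be 1, in which case Y's single row is broadcast against every row of X. With hypothetical shapes:

    X: [32, 128], Y: [32, 128]  ->  accepted; Out and XNorm are [32, 1]
    X: [32, 128], Y: [ 1, 128]  ->  accepted; Y's one row is compared against all 32 rows of X
    X: [32, 128], Y: [16, 128]  ->  rejected; first dimensions differ and Y's is not 1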
paddle/fluid/operators/group_norm_op.cc
@@ -30,14 +30,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"),
-                   "Output(Y) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
-                   "Output(Mean) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
-                   "Output(Variance) of GroupNormOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "GroupNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "GroupNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance",
+                   "GroupNorm");
 
     auto x_dim = ctx->GetInputDim("X");
     const std::string data_layout_str =
         ctx->Attrs().Get<std::string>("data_layout");
@@ -49,47 +47,52 @@ class GroupNormOp : public framework::OperatorWithKernel {
     auto groups = ctx->Attrs().Get<int>("groups");
     PADDLE_ENFORCE_LE(
         groups, channel_num,
-        "ValueError: the Attr(groups) of Op(group_norm) must be less than or "
-        "equal to the number of channels. "
-        "But received: groups is [%s], channels is [%s], the Attr(data_layout) "
-        "is [%s]. The error may come from wrong data_layout setting.",
-        groups, channel_num, data_layout_str);
+        platform::errors::InvalidArgument(
+            "The Attr(groups) of Op(group_norm) must be less than or "
+            "equal to the number of channels. But received: groups "
+            "is [%s], channels is [%s], the Attr(data_layout) "
+            "is [%s]. The error may come from wrong data_layout setting.",
+            groups, channel_num, data_layout_str));
     PADDLE_ENFORCE_GE(
         groups, 1,
-        "ValueError: the Attr(groups) of Op(group_norm) must be "
-        "greater than or equal to 1. But received: groups is [%s].",
-        groups);
+        platform::errors::InvalidArgument(
+            "The Attr(groups) of Op(group_norm) must be "
+            "greater than or equal to 1. But received: groups is [%s].",
+            groups));
 
     if (ctx->HasInput("Scale")) {
       PADDLE_ENFORCE_EQ(
           ctx->GetInputDim("Scale").size(), 1UL,
-          "ShapeError: the Input(Scale) of Op(group_norm) should be 1-D "
-          "Tensor. "
-          "But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
-          ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale"));
+          platform::errors::InvalidArgument(
+              "The Input(Scale) of Op(group_norm) should be 1-D Tensor. "
+              "But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
+              ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale")));
       PADDLE_ENFORCE_EQ(
           ctx->GetInputDim("Scale")[0], channel_num,
-          "ShapeError: the Input(Scale)'s first dimension size of "
-          "Op(group_norm) must be equal to the number of channels. "
-          "But received: the Input(Scale)'s first dimension size is [%s], the "
-          "channels is [%s], the Attr(data_layout) is [%s]. The error may come "
-          "from wrong data_layout setting.",
-          ctx->GetInputDim("Scale")[0], channel_num, data_layout_str);
+          platform::errors::InvalidArgument(
+              "The Input(Scale)'s first dimension size of Op(group_norm) must "
+              "be equal to the number of channels. But received: the "
+              "Input(Scale)'s first dimension size is [%s], the channels is "
+              "[%s], the Attr(data_layout) is [%s]. The error may come "
+              "from wrong data_layout setting.",
+              ctx->GetInputDim("Scale")[0], channel_num, data_layout_str));
     }
     if (ctx->HasInput("Bias")) {
       PADDLE_ENFORCE_EQ(
           ctx->GetInputDim("Bias").size(), 1UL,
-          "ShapeError: the Input(Bias) of Op(group_norm) should be 1-D Tensor. "
-          "But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
-          ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias"));
+          platform::errors::InvalidArgument(
+              "The Input(Bias) of Op(group_norm) should be 1-D Tensor. "
+              "But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
+              ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias")));
       PADDLE_ENFORCE_EQ(
           ctx->GetInputDim("Bias")[0], channel_num,
-          "ShapeError: the Input(Bias)'s first dimension size of "
-          "Op(group_norm) must be equal to the number of channels. "
-          "But received: the Input(Bias)'s first dimension size is [%s], the "
-          "channels is [%s], the Attr(data_layout) is [%s]. The error may come "
-          "from wrong data_layout setting.",
-          ctx->GetInputDim("Bias")[0], channel_num, data_layout_str);
+          platform::errors::InvalidArgument(
+              "The Input(Bias)'s first dimension size of "
+              "Op(group_norm) must be equal to the number of channels. "
+              "But received: the Input(Bias)'s first dimension size is [%s], "
+              "the channels is [%s], the Attr(data_layout) is [%s]. The "
+              "error may come from wrong data_layout setting.",
+              ctx->GetInputDim("Bias")[0], channel_num, data_layout_str));
     }
 
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
@@ -143,12 +146,11 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     // check input
-    PADDLE_ENFORCE(ctx->HasInput("Y"),
-                   "Input(Y) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Variance"),
-                   "Input(Variance) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) of GroupNormOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "GroupNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance",
+                   "GroupNormGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
+                   framework::GradVarName("Y"), "GroupNormGrad");
 
     // check output
     if (ctx->HasOutput(framework::GradVarName("X"))) {
@@ -168,18 +170,19 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     const auto *var = ctx.InputVar(framework::GradVarName("Y"));
-    if (var == nullptr) {
-      PADDLE_THROW("can't find Y@GRAD");
-    }
+
+    PADDLE_ENFORCE_NOT_NULL(
+        var, platform::errors::InvalidArgument(
+                 "Input(Y@GRAD) of GroupNormGradOp should not be null"));
     const Tensor *t = nullptr;
     if (var->IsType<Tensor>()) {
       t = &var->Get<Tensor>();
     } else if (var->IsType<LoDTensor>()) {
       t = &var->Get<LoDTensor>();
     }
-    if (t == nullptr) {
-      PADDLE_THROW("can't find Y@GRAD");
-    }
+    PADDLE_ENFORCE_NOT_NULL(
+        t, platform::errors::InvalidArgument(
+               "Input(Y@GRAD) Tensor of GroupNormGradOp should not be null"));
     return framework::OpKernelType(t->type(), ctx.GetPlace());
   }
 };
...
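For readers unfamiliar with the helper: OP_INOUT_CHECK centralizes the "should not be null" messages that the deleted lines spelled out by hand. A simplified sketch of what such a macro might expand to is shown below; this illustrates the mechanism only and is not Paddle's actual definition, which lives in the framework's enforce headers:

    // Illustrative only -- the real macro's message text and error type
    // (e.g. NotFound vs. InvalidArgument) may differ in Paddle.
    #define OP_INOUT_CHECK(condition, kind, name, op_name)          \
      PADDLE_ENFORCE_EQ(condition, true,                            \
                        platform::errors::NotFound(                 \
                            "No %s(%s) found for Op(%s).",          \
                            kind, name, op_name))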