未验证 提交 6c332ad6 编写于 作者: L liuwei1031 提交者: GitHub

imporve error messages for conv, conv_transpose, cos_sim, group_norm (#23675)

* imporve error messages for conv, conv_transpose, cos_sim, group_norm
上级 05476e9f
......@@ -31,12 +31,9 @@ namespace paddle {
namespace operators {
void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of ConvOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
......@@ -56,49 +53,53 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true,
"ShapeError: the input of Op(conv) should be 4-D or 5-D Tensor. But "
"received: %u-D Tensor, the shape of input is [%s].",
in_dims.size(), in_dims);
platform::errors::InvalidArgument(
"The input of Op(conv) should be 4-D or 5-D Tensor. But "
"received: %u-D Tensor, the shape of input is [%s].",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"ShapeError: the input's dimension size and filter's dimension size of "
"Op(conv) should be equal. But received: the shape of input is [%s], "
"the dimension size of input is [%d], the shape of filter is [%s], "
"the dimension size of filter is [%d].",
in_dims, in_dims.size(), filter_dims, filter_dims.size());
platform::errors::InvalidArgument(
"The input's dimension size and filter's dimension size of "
"Op(conv) should be equal. But received: the shape of input is [%s], "
"the dimension size of input is [%d], the shape of filter is [%s], "
"the dimension size of filter is [%d].",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int in_sub_stride_size = in_dims.size() - strides.size();
PADDLE_ENFORCE_EQ(in_dims.size() - strides.size() == 2U, true,
"ShapeError: the dimension size of input minus the size of "
"Attr(stride) must be euqal to 2 for Op(conv)."
"But received: the dimension size of input minus the size "
"of Attr(stride) is [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims,
strides.size());
PADDLE_ENFORCE_EQ(
in_dims.size(), strides.size() + 2U,
platform::errors::InvalidArgument(
"The dimension size of input minus the size of "
"Attr(stride) must be euqal to 2 for Op(conv)."
"But received: the dimension size of input minus the size "
"of Attr(stride) is [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(
input_channels, filter_dims[1] * groups,
"ShapeError: The number of input channels should be equal to filter "
"channels * groups for Op(conv). But received: the input's channels is "
"[%d], the shape "
"of input is [%s], the filter's channel is [%d], the shape of filter is "
"[%s], the groups is [%d], the data_format is %s. The error may come "
"from wrong data_format setting.",
input_channels, in_dims, filter_dims[1], filter_dims, groups,
data_format);
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter channels * "
"groups for Op(conv). But received: the input's channels is [%d], "
"the shape of input is [%s], the filter's channel is [%d], the shape "
"of filter is [%s], the groups is [%d], the data_format is %s. The "
"error may come from wrong data_format setting.",
input_channels, in_dims, filter_dims[1], filter_dims, groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
"ShapeError: The number of output channels of Op(conv) should be divided "
"by groups. "
"But received: the output channels is [%d], the shape of filter is [%s] "
"(the first dimension of filter is output channel), the groups is [%d].",
filter_dims[0], filter_dims, groups);
platform::errors::InvalidArgument(
"The number of output channels of Op(conv) should be divided "
"by groups. But received: the output channels is [%d], the shape "
"of filter is [%s] (the first dimension of filter is output "
"channel), the groups is [%d].",
filter_dims[0], filter_dims, groups));
framework::DDim in_data_dims;
framework::DDim filter_data_dims;
......@@ -169,11 +170,13 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
input_data_type != framework::proto::VarType::UINT8) {
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent");
platform::errors::InvalidArgument(
"input and filter data type should be consistent"));
}
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
platform::errors::InvalidArgument(
"float16 can only be used when CUDNN is used"));
}
auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
......
......@@ -29,12 +29,9 @@ namespace operators {
using DataLayout = framework::DataLayout;
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
"Input(Filter) of ConvTransposeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
"Output(Output) of ConvTransposeOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "ConvTranspose");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ConvTranspose");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
......@@ -53,42 +50,47 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
: framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
"ShapeError: input of Op(conv_transpose) should be 4-D or "
"5-D Tensor. But received: %u-D Tensor, "
"the shape of input is [%s]",
in_dims.size(), in_dims);
platform::errors::InvalidArgument(
"Input of Op(conv_transpose) should be 4-D or "
"5-D Tensor. But received: %u-D Tensor, "
"the shape of input is [%s]",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"ShapeError: the input's dimension size and filter's dimension size of "
"Op (conv_transpose) should be equal. But received: the shape of input "
"is [%s], the dimension size of input is [%d], the shape of filter is "
"[%s], the dimension size of filter is [%d]. ",
in_dims, in_dims.size(), filter_dims, filter_dims.size());
platform::errors::InvalidArgument(
"The input's dimension size and filter's dimension size of "
"Op (conv_transpose) should be equal. But received: the shape of "
"input is [%s], the dimension size of input is [%d], the shape "
"of filter is [%s], the dimension size of filter is [%d]. ",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int in_sub_stride_size = in_dims.size() - strides.size();
PADDLE_ENFORCE_EQ(
in_dims.size() - strides.size(), 2U,
"ShapeError: the input's dimension size minus Attr(stride)'s size must "
"be euqal to 2 for Op(conv_transpose). But received: [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims, strides.size());
platform::errors::InvalidArgument(
"The input's dimension size minus Attr(stride)'s size must "
"be euqal to 2 for Op(conv_transpose). But received: [%d], the "
"input's dimension size is [%d], the shape of input "
"is [%s], the Attr(stride)'s size is [%d].",
in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
if (output_size.size())
PADDLE_ENFORCE_EQ(
output_size.size(), strides.size(),
"The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
"should be the same.");
platform::errors::InvalidArgument(
"The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
"should be the same."));
const int64_t C =
(data_layout != DataLayout::kNHWC ? in_dims[1]
: in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ(
C, filter_dims[0],
"ShapeError: The number of input channels should be equal to filter "
"channels for Op(conv_transpose). But received: the input's channels is "
"[%d], the shape of input is [%s], the filter's channels is [%d], the "
"shape of filter is [%s]. The data_format is %s."
"The error may come from wrong data_format setting.",
C, in_dims, filter_dims[0], filter_dims, data_layout_str);
platform::errors::InvalidArgument(
"The number of input channels should be equal to filter channels "
"for Op(conv_transpose). But received: the input's channels is "
"[%d], the shape of input is [%s], the filter's channels is [%d], "
"the shape of filter is [%s]. The data_format is %s."
"The error may come from wrong data_format setting.",
C, in_dims, filter_dims[0], filter_dims, data_layout_str));
framework::DDim in_data_dims;
if (data_layout != DataLayout::kNHWC) {
......
......@@ -26,16 +26,11 @@ class CosSimOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
"Output(XNorm) of CosSimOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
"Output(YNorm) of CosSimOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSim");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSim");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CosSim");
OP_INOUT_CHECK(ctx->HasOutput("XNorm"), "Output", "XNorm", "CosSim");
OP_INOUT_CHECK(ctx->HasOutput("YNorm"), "Output", "YNorm", "CosSim");
// shape check
auto x_dims = ctx->GetInputDim("X");
......@@ -48,19 +43,28 @@ class CosSimOp : public framework::OperatorWithKernel {
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(
x_dims.size(), y_dims.size(),
platform::errors::InvalidArgument(
"Ranks of Input(X) [%s] and Input(Y) [%s] must be equal.", x_dims,
y_dims));
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"Rank of Input(X) %d must not be less than 2.", x_dims.size()));
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()),
"All dimensions except the 1st of Input(X) and Input(Y) "
"must be equal.");
platform::errors::InvalidArgument(
"All dimensions except the 1st of Input(X) [%s] and Input(Y) [%s]"
"must be equal.",
x_dims, y_dims));
PADDLE_ENFORCE(
x_dims[0] == y_dims[0] || y_dims[0] == 1,
"The 1st dimension of Input(Y) must be equal to Input(X) or"
" just 1 (which will be broadcasted to match Input(X)).");
platform::errors::InvalidArgument(
"The 1st dimension of Input(Y) %d must be equal to Input(X) %d or"
" just 1 (which will be broadcasted to match Input(X)).",
y_dims[0], x_dims[0]));
}
// resize tensor
......@@ -116,13 +120,13 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CosSimGrad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "CosSimGrad");
OP_INOUT_CHECK(ctx->HasInput("XNorm"), "Input", "XNorm", "CosSimGrad");
OP_INOUT_CHECK(ctx->HasInput("YNorm"), "Input", "YNorm", "CosSimGrad");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CosSimGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "CosSimGrad");
// shape check
auto x_dims = ctx->GetInputDim("X");
......@@ -133,26 +137,48 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()),
"All dimensions except the 1st of Input(X) and Input(Y) "
"must be equal.");
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
"The 1st dimension of Input(Y) must be equal to Input(X) or"
" just 1 (which will be broadcasted to match Input(X)).");
platform::errors::InvalidArgument(
"Ranks of Input(X) %d and Input(Y) %d must be equal.",
x_dims.size(), y_dims.size()));
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"Rank of Input(X) %d must not be less than 2.", x_dims.size()));
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()),
platform::errors::InvalidArgument(
"All dimensions except the 1st of Input(X) [%s] and Input(Y) [%s] "
"must be equal.",
x_dims, y_dims));
PADDLE_ENFORCE_EQ(
true, x_dims[0] == y_dims[0] || y_dims[0] == 1,
platform::errors::InvalidArgument(
"The 1st dimension of Input(Y) %d must be equal to Input(X) %d or"
" just 1 (which will be broadcasted to match Input(X)).",
y_dims[0], x_dims[0]));
auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1});
auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1});
PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
"Shape of Input(XNorm) must be [X.Dim(0), 1].");
PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
"Shape of Input(YNorm) must be [Y.Dim(0), 1].");
PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
"Shape of Input(Out) must be [X.Dim(0), 1].");
PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
PADDLE_ENFORCE_EQ(
xnorm_dims, target_xnorm_dims,
platform::errors::InvalidArgument(
"Shape of Input(XNorm) [%s] must be (X.Dim(0), 1) - [%s]",
xnorm_dims, target_xnorm_dims));
PADDLE_ENFORCE_EQ(
ynorm_dims, target_ynorm_dims,
platform::errors::InvalidArgument(
"Shape of Input(YNorm) [%s] must be (Y.Dim(0), 1) - [%s]",
ynorm_dims, target_ynorm_dims));
PADDLE_ENFORCE_EQ(
out_dims, target_xnorm_dims,
platform::errors::InvalidArgument(
"Shape of Input(Out) [%s] must be (X.Dim(0), 1) - [%s]", out_dims,
target_xnorm_dims));
PADDLE_ENFORCE_EQ(
out_grad_dims, target_xnorm_dims,
platform::errors::InvalidArgument(
"Shape of Input(Out@Grad) [%s] must be (X.Dim(0), 1) - [%s]",
out_grad_dims, target_xnorm_dims));
// resize tensor
auto x_grad_name = framework::GradVarName("X");
......
......@@ -30,14 +30,12 @@ class GroupNormOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mean"),
"Output(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Variance"),
"Output(Variance) of GroupNormOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Mean"), "Output", "Mean", "GroupNorm");
OP_INOUT_CHECK(ctx->HasOutput("Variance"), "Output", "Variance",
"GroupNorm");
auto x_dim = ctx->GetInputDim("X");
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_layout");
......@@ -49,47 +47,52 @@ class GroupNormOp : public framework::OperatorWithKernel {
auto groups = ctx->Attrs().Get<int>("groups");
PADDLE_ENFORCE_LE(
groups, channel_num,
"ValueError: the Attr(groups) of Op(group_norm) must be less than or "
"equal to the number of channels. "
"But received: groups is [%s], channels is [%s], the Attr(data_layout) "
"is [%s]. The error may come from wrong data_layout setting.",
groups, channel_num, data_layout_str);
platform::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be less than or "
"equal to the number of channels. But received: groups "
"is [%s], channels is [%s], the Attr(data_layout) "
"is [%s]. The error may come from wrong data_layout setting.",
groups, channel_num, data_layout_str));
PADDLE_ENFORCE_GE(
groups, 1,
"ValueError: the Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups);
platform::errors::InvalidArgument(
"The Attr(groups) of Op(group_norm) must be "
"greater than or equal to 1. But received: groups is [%s].",
groups));
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale").size(), 1UL,
"ShapeError: the Input(Scale) of Op(group_norm) should be 1-D "
"Tensor. "
"But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale"));
platform::errors::InvalidArgument(
"The Input(Scale) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Scale) is [%s].",
ctx->GetInputDim("Scale").size(), ctx->GetInputDim("Scale")));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Scale")[0], channel_num,
"ShapeError: the Input(Scale)'s first dimension size of "
"Op(group_norm) must be equal to the number of channels. "
"But received: the Input(Scale)'s first dimension size is [%s], the "
"channels is [%s], the Attr(data_layout) is [%s]. The error may come "
"from wrong data_layout setting.",
ctx->GetInputDim("Scale")[0], channel_num, data_layout_str);
platform::errors::InvalidArgument(
"The Input(Scale)'s first dimension size of Op(group_norm) must "
"be equal to the number of channels. But received: the "
"Input(Scale)'s first dimension size is [%s], the channels is "
"[%s], the Attr(data_layout) is [%s]. The error may come "
"from wrong data_layout setting.",
ctx->GetInputDim("Scale")[0], channel_num, data_layout_str));
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Bias").size(), 1UL,
"ShapeError: the Input(Bias) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias"));
platform::errors::InvalidArgument(
"The Input(Bias) of Op(group_norm) should be 1-D Tensor. "
"But received: %u-D Tensor, the shape of Input(Bias) is [%s].",
ctx->GetInputDim("Bias").size(), ctx->GetInputDim("Bias")));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Bias")[0], channel_num,
"ShapeError: the Input(Bias)'s first dimension size of "
"Op(group_norm) must be equal to the number of channels. "
"But received: the Input(Bias)'s first dimension size is [%s], the "
"channels is [%s], the Attr(data_layout) is [%s]. The error may come "
"from wrong data_layout setting.",
ctx->GetInputDim("Bias")[0], channel_num, data_layout_str);
platform::errors::InvalidArgument(
"The Input(Bias)'s first dimension size of "
"Op(group_norm) must be equal to the number of channels. "
"But received: the Input(Bias)'s first dimension size is [%s], "
"the channels is [%s], the Attr(data_layout) is [%s]. The "
"error may come from wrong data_layout setting.",
ctx->GetInputDim("Bias")[0], channel_num, data_layout_str));
}
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
......@@ -143,12 +146,11 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) of GroupNormOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance",
"GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "GroupNormGrad");
// check output
if (ctx->HasOutput(framework::GradVarName("X"))) {
......@@ -168,18 +170,19 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument(
"Input(Y@GRAD) of GroupNormGradOp should not be null"));
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
PADDLE_ENFORCE_NOT_NULL(
t, platform::errors::InvalidArgument(
"Input(Y@GRAD) Tensor of GroupNormGradOp should not be null"));
return framework::OpKernelType(t->type(), ctx.GetPlace());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册