未验证 提交 5fe3b638 编写于 作者: H huzhiqiang 提交者: GitHub

[error message enhancement] fused_elemwise_activation_op and fusion_conv_inception_op (#23686)

上级 c4e6e206
......@@ -20,7 +20,11 @@ namespace paddle {
namespace operators {
bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
PADDLE_ENFORCE_EQ(
functor_list.size(), 2,
platform::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functor_list.size(), 2));
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
"elementwise_mul_grad"};
......@@ -28,7 +32,11 @@ bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
}
bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
PADDLE_ENFORCE_EQ(
functor_list.size(), 2,
platform::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functor_list.size(), 2));
static std::unordered_set<std::string> InplaceOpSet = {"relu", "relu_grad"};
bool is_in_place = false;
for (auto &func_name : functor_list) {
......@@ -38,7 +46,11 @@ bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
}
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
PADDLE_ENFORCE_EQ(
functor_list.size(), 2,
platform::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functor_list.size(), 2));
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
binary_fun.count(functor_list[1]) != 0;
......@@ -50,7 +62,11 @@ bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
* out.
*/
static bool IsSupportedCompound(const std::vector<std::string> &functors) {
PADDLE_ENFORCE_EQ(functors.size(), 2UL);
PADDLE_ENFORCE_EQ(
functors.size(), 2UL,
platform::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functors.size(), 2));
static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
"sigmoid"};
......@@ -63,11 +79,12 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
} else if (binary_fun.count(functors[1])) {
unary_fun_str = functors[0];
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
functors[1]);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s and %s are not included in fused_list.", functors[0], functors[1]));
}
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str);
platform::errors::InvalidArgument(
"%s is not included in fused_list.", unary_fun_str));
return true;
}
......@@ -76,15 +93,18 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of FusedElemwiseActivationOp op should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Y"),
"Input(Y) of FusedElemwiseActivationOp op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusedElemwiseActivationOp op should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of FusedElemwiseActivationOp op should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::InvalidArgument(
"Input(Y) of FusedElemwiseActivationOp op should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of FusedElemwiseActivationOp op should not be null."));
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
......@@ -97,9 +117,11 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("IntermediateOut"), true,
platform::errors::InvalidArgument(
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null."));
if (IsUnaryCompound(
ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
......@@ -139,7 +161,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
platform::errors::InvalidArgument(
"The element's type of input should be the same."));
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
......@@ -173,7 +196,10 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<std::string>>("functor_list",
"The functors that should be fused.")
.AddCustomChecker([&](const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE(IsSupportedCompound(functor_list));
PADDLE_ENFORCE_EQ(
IsSupportedCompound(functor_list), true,
platform::errors::InvalidArgument(
"the input functors should support compounding."));
});
AddComment(R"DOC(
......@@ -266,18 +292,22 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@Grad) should not be null."));
auto functor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
"Input(IntermediateOut) should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput("IntermediateOut"), true,
platform::errors::InvalidArgument(
"Input(IntermediateOut) should not be null."));
} else {
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null."));
}
}
......@@ -292,9 +322,11 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
} else {
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
PADDLE_ENFORCE_EQ(
InputXCanBeAbsent(functor_list), true,
platform::errors::InvalidArgument(
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent."));
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, otherwise, we could not infer the shape of dx.
......@@ -306,7 +338,9 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
}
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::InvalidArgument("Input(Y) should not be null."));
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
}
......
......@@ -32,10 +32,21 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel {
// 4 filters
auto w_dims = ctx->GetInputsDim("Filter");
PADDLE_ENFORCE(in_dims.size(), 4, "Conv intput should be 4-D tensor.");
PADDLE_ENFORCE_EQ(w_dims.size(), 4, "There should be 4 filters");
PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1]);
PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1]);
PADDLE_ENFORCE_EQ(
in_dims.size(), 4,
platform::errors::InvalidArgument("Conv intput should be 4-D tensor."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4, platform::errors::InvalidArgument(
"There should be 4 filters."));
PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1],
platform::errors::InvalidArgument(
"Invalid fileter channel number %d, which should be "
"equal to input channel number %d.",
w_dims[0][1], in_dims[1]));
PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1],
platform::errors::InvalidArgument(
"Invalid fileter channel number %d, which should be "
"equal to input channel number %d.",
w_dims[1][1], in_dims[1]));
int n = in_dims[0];
// compute output channel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册