Unverified commit 5fe3b638, authored by huzhiqiang, committed by GitHub

[error message enhancement] fused_elemwise_activation_op and fusion_conv_inception_op (#23686)

Parent: c4e6e206
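Every hunk below follows the same pattern: a bare PADDLE_ENFORCE / PADDLE_ENFORCE_EQ / PADDLE_THROW, which failed with an untyped or uninformative error, is rewritten to raise a typed platform::errors::InvalidArgument carrying a descriptive message. The two recurring idioms, lifted directly from the diff itself, are:

    // Value comparison: the message echoes both operands via printf-style %d.
    PADDLE_ENFORCE_EQ(
        functor_list.size(), 2,
        platform::errors::InvalidArgument(
            "Invalid functor list size %d, which should be equal to %d.",
            functor_list.size(), 2));

    // Boolean condition: PADDLE_ENFORCE(cond, msg) becomes an explicit
    // comparison against true, again with a typed error.
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::InvalidArgument(
            "Input(X) of FusedElemwiseActivationOp op should not be null."));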
@@ -20,7 +20,11 @@ namespace paddle {
 namespace operators {
 
 bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> binary_fun = {
       "elementwise_add", "elementwise_mul", "elementwise_add_grad",
       "elementwise_mul_grad"};
@@ -28,7 +32,11 @@ bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
 }
 
 bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> InplaceOpSet = {"relu", "relu_grad"};
   bool is_in_place = false;
   for (auto &func_name : functor_list) {
@@ -38,7 +46,11 @@ bool HasInPlaceUnary(const std::vector<std::string> &functor_list) {
 }
 
 bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
-  PADDLE_ENFORCE_EQ(functor_list.size(), 2);
+  PADDLE_ENFORCE_EQ(
+      functor_list.size(), 2,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functor_list.size(), 2));
   static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
   return binary_fun.count(functor_list[0]) != 0 ||
          binary_fun.count(functor_list[1]) != 0;
@@ -50,7 +62,11 @@ bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
  * out.
  */
 static bool IsSupportedCompound(const std::vector<std::string> &functors) {
-  PADDLE_ENFORCE_EQ(functors.size(), 2UL);
+  PADDLE_ENFORCE_EQ(
+      functors.size(), 2UL,
+      platform::errors::InvalidArgument(
+          "Invalid functor list size %d, which should be equal to %d.",
+          functors.size(), 2));
   static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
                                                       "sigmoid"};
@@ -63,11 +79,12 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
   } else if (binary_fun.count(functors[1])) {
     unary_fun_str = functors[0];
   } else {
-    PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
-                 functors[1]);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s and %s are not included in fused_list.", functors[0], functors[1]));
   }
   PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
-                    "%s is not included in fused_list.", unary_fun_str);
+                    platform::errors::InvalidArgument(
+                        "%s is not included in fused_list.", unary_fun_str));
   return true;
 }
@@ -76,15 +93,18 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(
-        ctx->HasInput("X"),
-        "Input(X) of FusedElemwiseActivationOp op should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasInput("Y"),
-        "Input(Y) of FusedElemwiseActivationOp op should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("Out"),
-        "Output(Out) of FusedElemwiseActivationOp op should not be null.");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument(
+            "Input(X) of FusedElemwiseActivationOp op should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Y"), true,
+        platform::errors::InvalidArgument(
+            "Input(Y) of FusedElemwiseActivationOp op should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of FusedElemwiseActivationOp op should not be null."));
 
     auto x_dim = ctx->GetInputDim("X");
     auto y_dim = ctx->GetInputDim("Y");
@@ -97,9 +117,11 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
     std::string out_lod = bcast_y ? "X" : "Y";
 
     if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
-      PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
-                     "Output(IntermediateOut) of FusedElemwiseActivationOp "
-                     "should not be null.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasOutput("IntermediateOut"), true,
+          platform::errors::InvalidArgument(
+              "Output(IntermediateOut) of FusedElemwiseActivationOp "
+              "should not be null."));
 
       if (IsUnaryCompound(
               ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
@@ -139,7 +161,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
                       ctx.Input<framework::Tensor>("Y")->type(),
-                      "The element's type of input should be the same.");
+                      platform::errors::InvalidArgument(
+                          "The element's type of input should be the same."));
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
@@ -173,7 +196,10 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<std::string>>("functor_list",
                                       "The functors that should be fused.")
         .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
-          PADDLE_ENFORCE(IsSupportedCompound(functor_list));
+          PADDLE_ENFORCE_EQ(
+              IsSupportedCompound(functor_list), true,
+              platform::errors::InvalidArgument(
+                  "the input functors should support compounding."));
         });
 
     AddComment(R"DOC(
@@ -266,18 +292,22 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@Grad) should not be null");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Out@Grad) should not be null."));
 
     auto functor_list =
         ctx->Attrs().Get<std::vector<std::string>>("functor_list");
 
     if (ctx->Attrs().Get<bool>("save_intermediate_out")) {
-      PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
-                     "Input(IntermediateOut) should not be null");
+      PADDLE_ENFORCE_EQ(ctx->HasInput("IntermediateOut"), true,
+                        platform::errors::InvalidArgument(
+                            "Input(IntermediateOut) should not be null."));
     } else {
       if (!InputXCanBeAbsent(functor_list)) {
-        PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+        PADDLE_ENFORCE_EQ(
+            ctx->HasInput("X"), true,
+            platform::errors::InvalidArgument("Input(X) should not be null."));
       }
     }
@@ -292,9 +322,11 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
     } else {
       // Currently, only when Binary is elementwise_add or elementwise_sub,
      // the "X" could be absent.
-      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
-                     "Only when BinaryFunctor is elementwise_add, the 'X' "
-                     "could be absent.");
+      PADDLE_ENFORCE_EQ(
+          InputXCanBeAbsent(functor_list), true,
+          platform::errors::InvalidArgument(
+              "Only when BinaryFunctor is elementwise_add, the 'X' "
+              "could be absent."));
 
       // Node: If "X" is absence, the shape of Y should be a continuous
       // subsequence of X, otherwise, we could not infer the shape of dx.
@@ -306,7 +338,9 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
     }
 
     if (ctx->HasOutput(y_grad_name)) {
-      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("Y"), true,
+          platform::errors::InvalidArgument("Input(Y) should not be null."));
       ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
       ctx->ShareLoD("Y", y_grad_name);
     }
...
@@ -32,10 +32,21 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel {
     // 4 filters
     auto w_dims = ctx->GetInputsDim("Filter");
 
-    PADDLE_ENFORCE(in_dims.size(), 4, "Conv intput should be 4-D tensor.");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 4, "There should be 4 filters");
-    PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1]);
-    PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1]);
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 4,
+        platform::errors::InvalidArgument("Conv intput should be 4-D tensor."));
+    PADDLE_ENFORCE_EQ(w_dims.size(), 4, platform::errors::InvalidArgument(
+                                            "There should be 4 filters."));
+    PADDLE_ENFORCE_EQ(w_dims[0][1], in_dims[1],
+                      platform::errors::InvalidArgument(
+                          "Invalid fileter channel number %d, which should be "
+                          "equal to input channel number %d.",
+                          w_dims[0][1], in_dims[1]));
+    PADDLE_ENFORCE_EQ(w_dims[1][1], in_dims[1],
+                      platform::errors::InvalidArgument(
+                          "Invalid fileter channel number %d, which should be "
+                          "equal to input channel number %d.",
+                          w_dims[1][1], in_dims[1]));
 
     int n = in_dims[0];
     // compute output channel
...
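One functional fix rides along with the message cleanup: the old PADDLE_ENFORCE(in_dims.size(), 4, ...) in ConvInceptionFusionOp never actually compared the rank to 4, since PADDLE_ENFORCE only tests its first argument for truthiness and consumed the 4 as a message argument, so the check passed for any non-empty dims. The switch to PADDLE_ENFORCE_EQ makes the 4-D requirement real. The channel checks that follow compare w_dims[i][1] against in_dims[1]; a minimal standalone sketch of why, assuming the usual NCHW input and [out_c, in_c, kh, kw] filter layout (the shapes below are hypothetical, not from the commit):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<int64_t> in_dims = {8, 64, 32, 32};  // NCHW: channels = in_dims[1]
      std::vector<int64_t> w0 = {96, 64, 1, 1};        // [out_c, in_c, kh, kw]
      // A conv branch is well-formed only when the filter's in_c matches the
      // input channel count, the same condition the new
      // PADDLE_ENFORCE_EQ(w_dims[i][1], in_dims[1], ...) enforces per branch.
      assert(w0[1] == in_dims[1]);
      return 0;
    }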