未验证 提交 dc439a12 编写于 作者: Z Zhong Hui 提交者: GitHub

Enhance tensor shape check for dist op. (#34915)

上级 fd92d949
...@@ -27,6 +27,20 @@ class DistOp : public framework::OperatorWithKernel { ...@@ -27,6 +27,20 @@ class DistOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dist"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dist");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Dist"); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Dist");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Dist"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Dist");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
platform::errors::InvalidArgument(
"The Input(X) has not been initialized properly. The "
"shape of Input(X) = [%s].",
x_dims));
PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
platform::errors::InvalidArgument(
"The Input(Y) has not been initialized properly. The "
"shape of Input(Y) = [%s].",
y_dims));
ctx->SetOutputDim("Out", {1}); ctx->SetOutputDim("Out", {1});
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册