未验证 提交 ca7f5208 编写于 作者: C ceci3 提交者: GitHub

fix batch_norm and instance norm when input is [] (#34107)

* fix batch_norm and instance norm when input is []
上级 255fc7d8
...@@ -55,6 +55,16 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -55,6 +55,16 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
"Variance and VarianceOut should share the same memory")); "Variance and VarianceOut should share the same memory"));
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
for (int i = 0; i < x_dims.size(); i++) {
PADDLE_ENFORCE_EQ(
(x_dims[i] == -1) || (x_dims[i] > 0), true,
platform::errors::InvalidArgument(
"Each dimension of input tensor is expected to be -1 or a "
"positive number, but recieved %d. Input's shape is [%s].",
x_dims[i], x_dims));
}
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout")); ctx->Attrs().Get<std::string>("data_layout"));
......
...@@ -32,6 +32,13 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -32,6 +32,13 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
"InstanceNorm"); "InstanceNorm");
const auto x_dims = ctx->GetInputDim("X"); const auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable X(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("X").front()));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
x_dims.size(), 2, x_dims.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册