提交 285a0b7b 编写于 作者: C ceci3

test=develop

上级 f8c764c0
......@@ -68,6 +68,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
auto scale_dim = ctx->GetInputDim("Scale");
auto bias_dim = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
......@@ -75,9 +78,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim[0], C);
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim[0], C);
}
ctx->SetOutputDim("Y", x_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册