未验证 提交 18549417 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix batch_norm op infermeta bug (#47858)

上级 17dffd13
......@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
}
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
......
......@@ -40,10 +40,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
......
......@@ -123,10 +123,11 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
......
......@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x,
if (saved_variance) {
saved_variance->set_dims({C});
}
if (reserve_space) {
reserve_space->set_dims({-1});
}
y->share_lod(x);
y->set_dtype(x.dtype());
}
......
......@@ -45,10 +45,11 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
......
......@@ -326,10 +326,11 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_array(max_dim);
......@@ -394,10 +395,11 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
......
......@@ -287,10 +287,11 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
......@@ -1725,10 +1726,11 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
......
......@@ -51,10 +51,11 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
......@@ -121,10 +122,11 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册