未验证 提交 18549417 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix batch_norm op infermeta bug (#47858)

上级 17dffd13
...@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C}); ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y"); ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
} }
framework::OpKernelType BatchNormOp::GetExpectedKernelType( framework::OpKernelType BatchNormOp::GetExpectedKernelType(
......
...@@ -40,10 +40,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, ...@@ -40,10 +40,11 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
if (x_dims.size() > y_dims.size()) { if (x_dims.size() > y_dims.size()) {
......
...@@ -123,10 +123,11 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx, ...@@ -123,10 +123,11 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
......
...@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x, ...@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x,
if (saved_variance) { if (saved_variance) {
saved_variance->set_dims({C}); saved_variance->set_dims({C});
} }
if (reserve_space) {
reserve_space->set_dims({-1});
}
y->share_lod(x); y->share_lod(x);
y->set_dtype(x.dtype()); y->set_dtype(x.dtype());
} }
......
...@@ -45,10 +45,11 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, ...@@ -45,10 +45,11 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
if (x_dims.size() > y_dims.size()) { if (x_dims.size() > y_dims.size()) {
......
...@@ -326,10 +326,11 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, ...@@ -326,10 +326,11 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
std::vector<int> x_dims_array(max_dim); std::vector<int> x_dims_array(max_dim);
...@@ -394,10 +395,11 @@ void ElementwiseCompute(const CPUContext &dev_ctx, ...@@ -394,10 +395,11 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
......
...@@ -287,10 +287,11 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, ...@@ -287,10 +287,11 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
...@@ -1725,10 +1726,11 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, ...@@ -1725,10 +1726,11 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
......
...@@ -51,10 +51,11 @@ void XPUElementwise(const XPUContext& dev_ctx, ...@@ -51,10 +51,11 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
std::vector<int> x_dims_vec(max_dim, 1); std::vector<int> x_dims_vec(max_dim, 1);
...@@ -121,10 +122,11 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx, ...@@ -121,10 +122,11 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
axis,
max_dim, max_dim,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", "Axis should be less than or equal to %d, but received axis is %d.",
max_dim, max_dim,
axis)); axis));
std::vector<int> x_dims_vec(max_dim, 1); std::vector<int> x_dims_vec(max_dim, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册