未验证 提交 18549417 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix batch_norm op infermeta bug (#47858)

上级 17dffd13
...@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C}); ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y"); ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
} }
framework::OpKernelType BatchNormOp::GetExpectedKernelType( framework::OpKernelType BatchNormOp::GetExpectedKernelType(
......
...@@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, ...@@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
platform::errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", platform::errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
if (x_dims.size() > y_dims.size()) { if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1); std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) { if (axis + y_dims.size() < max_dim) {
......
...@@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx, ...@@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
platform::errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", platform::errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
dst_dims_vec[i + x_axis] = dst_dims_vec[i + x_axis] =
......
...@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x, ...@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x,
if (saved_variance) { if (saved_variance) {
saved_variance->set_dims({C}); saved_variance->set_dims({C});
} }
if (reserve_space) {
reserve_space->set_dims({-1});
}
y->share_lod(x); y->share_lod(x);
y->set_dtype(x.dtype()); y->set_dtype(x.dtype());
} }
......
...@@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, ...@@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
phi::errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", phi::errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
if (x_dims.size() > y_dims.size()) { if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1); std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) { if (axis + y_dims.size() < max_dim) {
......
...@@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, ...@@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
phi::errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", phi::errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
std::vector<int> x_dims_array(max_dim); std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim); std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim); std::vector<int> out_dims_array(max_dim);
...@@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx, ...@@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0; int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) { if (is_xsize_larger) {
......
...@@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, ...@@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0; int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) { if (is_xsize_larger) {
...@@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, ...@@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0; int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) { if (is_xsize_larger) {
......
...@@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx, ...@@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1); std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1); std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) { if (x_dims.size() == max_dim) {
...@@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx, ...@@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument( errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.", "Axis should be great than or equal to 0, but received axis is %d.",
axis)); axis));
PADDLE_ENFORCE_LE(axis, PADDLE_ENFORCE_LE(
max_dim, axis,
errors::InvalidArgument( max_dim,
"Axis should be less than %d, but received axis is %d.", errors::InvalidArgument(
max_dim, "Axis should be less than or equal to %d, but received axis is %d.",
axis)); max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1); std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1); std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) { if (x_dims.size() == max_dim) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册