未验证 提交 18549417 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix batch_norm op infermeta bug (#47858)

上级 17dffd13
......@@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
}
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
......
......@@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
......
......@@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
for (int i = 0; i < x_dims.size(); ++i) {
dst_dims_vec[i + x_axis] =
......
......@@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x,
if (saved_variance) {
saved_variance->set_dims({C});
}
if (reserve_space) {
reserve_space->set_dims({-1});
}
y->share_lod(x);
y->set_dtype(x.dtype());
}
......
......@@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
......
......@@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
......@@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
......
......@@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
......@@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
......
......@@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
......@@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册