提交 87d89dfe 编写于 作者: C ceci3

fix batch_norm and cos_sim infer shape, test=develop

上级 c79cdf25
...@@ -65,11 +65,21 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -65,11 +65,21 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); auto scale_dim = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); auto bias_dim = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C);
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim[0], C);
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dim[0], C);
}
ctx->SetOutputDim("Y", x_dims); ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("MeanOut", {C}); ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C}); ctx->SetOutputDim("VarianceOut", {C});
......
...@@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal."); "Ranks of Input(X) and Input(Y) must be equal.");
PADDLE_ENFORCE_GE(x_dims.size(), 2, PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) must not be less than 2."); "Rank of Input(X) must not be less than 2.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()), PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, 1, x_dims.size()),
framework::slice_ddim(y_dims, 1, y_dims.size()), framework::slice_ddim(y_dims, 1, y_dims.size()),
"All dimensions except the 1st of Input(X) and Input(Y) " "All dimensions except the 1st of Input(X) and Input(Y) "
"must be equal."); "must be equal.");
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1, PADDLE_ENFORCE(
x_dims[0] == y_dims[0] || y_dims[0] == 1,
"The 1st dimension of Input(Y) must be equal to Input(X) or" "The 1st dimension of Input(Y) must be equal to Input(X) or"
" just 1 (which will be broadcasted to match Input(X))."); " just 1 (which will be broadcasted to match Input(X)).");
}
// resize tensor // resize tensor
ctx->SetOutputDim("Out", {x_dims[0], 1}); ctx->SetOutputDim("Out", {x_dims[0], 1});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册