未验证 提交 779ffb84 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16876 from tink2123/infer_shape

modified infer shape for op
...@@ -79,9 +79,13 @@ class AffineChannelOp : public framework::OperatorWithKernel { ...@@ -79,9 +79,13 @@ class AffineChannelOp : public framework::OperatorWithKernel {
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL); PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dims[0], C);
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL); PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
if (ctx->IsRuntime() || scale_dims[0] > 0) {
PADDLE_ENFORCE_EQ(scale_dims[0], C);
}
if (ctx->IsRuntime() || b_dims[0] > 0) {
PADDLE_ENFORCE_EQ(b_dims[0], C); PADDLE_ENFORCE_EQ(b_dims[0], C);
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
......
...@@ -68,10 +68,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -68,10 +68,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], dilations[i], paddings[i],
strides[i])); strides[i]));
} }
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output"); ctx->ShareLoD("Input", "Output");
} }
......
...@@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(label_dims.size(), 2, PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The rank of Input(Label) must be 2, " "The rank of Input(Label) must be 2, "
"the shape is [N, 6]."); "the shape is [N, 6].");
if (ctx->IsRuntime() || label_dims[1] > 0) {
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5, PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
"The shape of Input(Label) is [N, 6] or [N, 5]."); "The shape of Input(Label) is [N, 6] or [N, 5].");
}
if (ctx->HasInput("PosCount")) { if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"), PADDLE_ENFORCE(ctx->HasInput("TruePos"),
......
...@@ -45,9 +45,12 @@ class RowConvOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,12 @@ class RowConvOp : public framework::OperatorWithKernel {
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2."); PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[1], filter_dims[1], x_dims[1], filter_dims[1],
"The 2nd dimension of Input(X) and Input(Filter) should be same."); "The 2nd dimension of Input(X) and Input(Filter) should be same.");
}
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
......
...@@ -99,11 +99,16 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -99,11 +99,16 @@ class UnpoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(in_x_dims.size() == 4, PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput must be of 4-dimensional."); "Unpooling intput must be of 4-dimensional.");
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]}); std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
} else {
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i])); paddings[i], strides[i]));
} }
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册