未验证 提交 591f0879 编写于 作者: X XiaoguangHu 提交者: GitHub

Merge pull request #16932 from SunGaofeng/infershape14

Infer shape of pad_op pad_constant_like_op for version 1.4.0
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/pad_constant_like_op.h" #include "paddle/fluid/operators/pad_constant_like_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,8 +39,16 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { ...@@ -38,8 +39,16 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
"The dimention of X and Y should be the same."); "The dimention of X and Y should be the same.");
for (int i = 0; i < x_dim.size(); ++i) { for (int i = 0; i < x_dim.size(); ++i) {
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]); if ((!ctx->IsRuntime()) && ((x_dim[i] == -1) || (y_dim[i] == -1))) {
continue;
} else {
PADDLE_ENFORCE_GE(
x_dim[i], y_dim[i],
"expected X_dim[i] >= Y_dim[i], but received %d < %d for dim %d",
x_dim[i], y_dim[i], i);
} }
}
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -162,7 +171,14 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { ...@@ -162,7 +171,14 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
ctx->ShareLoD("Y", /*->*/ y_grad_name); ctx->ShareLoD("Y", /*->*/ y_grad_name);
for (int i = 0; i < y_dim.size(); ++i) { for (int i = 0; i < y_dim.size(); ++i) {
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]); if ((!ctx->IsRuntime()) && ((dout_dim[i] == -1) || (y_dim[i] == -1))) {
continue;
} else {
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i],
"expected Out_dim[i] >= Y_dim[i], but received %d "
"< %d for dim %d",
dout_dim[i], y_dim[i], i);
}
} }
} }
} }
......
...@@ -34,10 +34,17 @@ class PadOp : public framework::OperatorWithKernel { ...@@ -34,10 +34,17 @@ class PadOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
"Size of paddings should be equal to 2 * dimension size " "Size of paddings should be equal to 2 * dimension size "
"of input tensor."); "of input tensor.");
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_GE(paddings[i], 0, "paddings should >= 0.");
}
std::vector<int64_t> out_dims(x_dim.size()); std::vector<int64_t> out_dims(x_dim.size());
for (int i = 0; i < x_dim.size(); ++i) { for (int i = 0; i < x_dim.size(); ++i) {
if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
out_dims[i] = -1;
} else {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
} }
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) { if (out_dims[0] == x_dim[0]) {
// Only pass LoD when the first dimension is equal between // Only pass LoD when the first dimension is equal between
...@@ -100,19 +107,15 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -100,19 +107,15 @@ class PadOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
}
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) { for (int i = 0; i < dout_dims.size(); ++i) {
if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]); dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
} }
}
ctx->SetOutputDim(x_grad_name, dout_dims); ctx->SetOutputDim(x_grad_name, dout_dims);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册