未验证 提交 9845b222 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #16877 from heavengate/fix_infer_shape_pick

[cherry-pick] infer shape: grid_sampler, kldiv_loss, spectral_norm, interpolate
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h" #include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
...@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
"Input(X) of GridSampleOp should be 4-D Tensor."); "Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4, PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor."); "Input(Grid) of GridSampleOp should be 4-D Tensor.");
if (ctx->IsRuntime() || grid_dims[3] > 0) {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal."); "Input(X) and Input(Grid) dims[0] should be equal.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2], grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."); "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
......
...@@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
int out_h = ctx->Attrs().Get<int>("out_h"); int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w"); int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize"); auto out_size_dim = ctx->GetInputDim("OutSize");
...@@ -49,6 +51,7 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -49,6 +51,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
return; return;
} }
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w}); std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} }
......
...@@ -35,9 +35,11 @@ class KLDivLossOp : public framework::OperatorWithKernel { ...@@ -35,9 +35,11 @@ class KLDivLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same."); "Input(X) rank and Input(Target) rank should be same.");
for (int i = 0; i < dim_x.size(); i++) { for (int i = 0; i < dim_x.size(); i++) {
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
"Input(X) and Input(Target) should in same shape."); "Input(X) and Input(Target) should in same shape.");
} }
}
auto reduction = ctx->Attrs().Get<std::string>("reduction"); auto reduction = ctx->Attrs().Get<std::string>("reduction");
......
...@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel {
} }
auto dim_u = ctx->GetInputDim("U"); auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V"); auto dim_v = ctx->GetInputDim("V");
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0], h, PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to " "Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]"); "Input(Weight) dims[Attr(dim)]");
}
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_v[0], w, dim_v[0], w,
"Input(V) dims[0] should be equal to " "Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]"); "the product of Input(Weight) dims except dims[Attr(dim)]");
}
ctx->SetOutputDim("Out", dim_weight); ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out"); ctx->ShareLoD("Weight", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册