提交 b2dcdb51 编写于 作者: D dengkaipeng

infer shape compatable -1. test=develop

上级 82cff5ec
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h" #include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
...@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
"Input(X) of GridSampleOp should be 4-D Tensor."); "Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4, PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor."); "Input(Grid) of GridSampleOp should be 4-D Tensor.");
if (ctx->IsRuntime() || grid_dims[3] > 0) {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2."); PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal."); "Input(X) and Input(Grid) dims[0] should be equal.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2], grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."); "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
......
...@@ -58,8 +58,13 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -58,8 +58,13 @@ class InterpolateOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
return; return;
} }
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w}); std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} else {
ctx->SetOutputDim("Out", dim_x);
}
} }
protected: protected:
......
...@@ -35,9 +35,11 @@ class KLDivLossOp : public framework::OperatorWithKernel { ...@@ -35,9 +35,11 @@ class KLDivLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(), PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same."); "Input(X) rank and Input(Target) rank should be same.");
for (int i = 0; i < dim_x.size(); i++) { for (int i = 0; i < dim_x.size(); i++) {
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i], PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
"Input(X) and Input(Target) should in same shape."); "Input(X) and Input(Target) should in same shape.");
} }
}
auto reduction = ctx->Attrs().Get<std::string>("reduction"); auto reduction = ctx->Attrs().Get<std::string>("reduction");
......
...@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel {
} }
auto dim_u = ctx->GetInputDim("U"); auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V"); auto dim_v = ctx->GetInputDim("V");
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0], h, PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to " "Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]"); "Input(Weight) dims[Attr(dim)]");
}
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_v[0], w, dim_v[0], w,
"Input(V) dims[0] should be equal to " "Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]"); "the product of Input(Weight) dims except dims[Attr(dim)]");
}
ctx->SetOutputDim("Out", dim_weight); ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out"); ctx->ShareLoD("Weight", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册