提交 70a967df 编写于 作者: D dengkaipeng

infer shape compatable -1. test=develop

上级 cf5af3b5
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
"Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
if (ctx->IsRuntime() || grid_dims[3] > 0) {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
......
......@@ -49,8 +49,13 @@ class InterpolateOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out");
return;
}
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} else {
ctx->SetOutputDim("Out", dim_x);
}
}
protected:
......
......@@ -35,9 +35,11 @@ class KLDivLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same.");
for (int i = 0; i < dim_x.size(); i++) {
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
"Input(X) and Input(Target) should in same shape.");
}
}
auto reduction = ctx->Attrs().Get<std::string>("reduction");
......
......@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel {
}
auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V");
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]");
}
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ(
dim_v[0], w,
"Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]");
}
ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册