提交 87916f8d 编写于 作者: P phlrain

simple code;test=develop

上级 165a7bd5
...@@ -34,18 +34,9 @@ class HuberLossOp : public framework::OperatorWithKernel { ...@@ -34,18 +34,9 @@ class HuberLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dims.size(), 2, PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) must be 2 and the shape is " "The rank of Input(X) must be 2 and the shape is "
"[batch_size, 1]."); "[batch_size, 1].");
if (ctx->IsRuntime()) { if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same"); PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same");
} else {
if (x_dims[0] != -1 && y_dims[0] != -1) {
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The dim 0 of X and Y must be the same.");
}
if (x_dims[1] != -1 && y_dims[1] != -1) {
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[1],
"The dim 1 of X and Y must be the same.");
}
} }
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[1], 1, PADDLE_ENFORCE_EQ(x_dims[1], 1,
......
...@@ -39,16 +39,11 @@ class MinusOp : public framework::OperatorWithKernel { ...@@ -39,16 +39,11 @@ class MinusOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
if (ctx->IsRuntime()) { if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims, y_dims, x_dims, y_dims,
"Minus operator must take two tensor with same num of elements"); "Minus operator must take two tensor with same num of elements");
} else {
if (framework::product(x_dims) > 0 && framework::product(y_dims) > 0) {
PADDLE_ENFORCE_EQ(
x_dims, y_dims,
"Minus operator must take two tensor with same num of elements");
}
} }
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
......
...@@ -29,19 +29,10 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { ...@@ -29,19 +29,10 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2.");
if (ctx->IsRuntime()) { if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(x_dims, y_dims, PADDLE_ENFORCE_EQ(x_dims, y_dims,
"The shape of X and Y must be the same."); "The shape of X and Y must be the same.");
} else {
if (x_dims[0] != -1 && y_dims[0] != -1) {
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The dim 0 of X and Y must be the same.");
}
if (x_dims[1] != -1 && y_dims[1] != -1) {
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[1],
"The dim 1 of X and Y must be the same.");
}
} }
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册