提交 8b59ac3a 编写于 作者: C Chen Weihang 提交者: Tao Luo

delete paddle infershape enforce marco (#20832)

上级 c8e49be2
...@@ -184,31 +184,35 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -184,31 +184,35 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
true, true,
"The Input(Label) should be a 3-D tensor with last " "The Input(Label) should be a 3-D tensor with last "
"dimension fixed to 1 or a 2-D tensor in padding mode."); "dimension fixed to 1 or a 2-D tensor in padding mode.");
PADDLE_INFERSHAPE_ENFORCE_EQ( if (ctx->IsRuntime()) {
ctx, emission_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(emission_dims[0], label_dims[0],
"The batch size of Input(Emission) and Input(Label) " "The batch size of Input(Emission) and Input(Label) "
"should be the same."); "should be the same.");
PADDLE_INFERSHAPE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(emission_dims[1], label_dims[1],
ctx, emission_dims[1], label_dims[1],
"The max length of Input(Emission) and Input(Label) " "The max length of Input(Emission) and Input(Label) "
"should be the same."); "should be the same.");
}
} else { } else {
PADDLE_ENFORCE_EQ(emission_dims.size(), 2, PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
"The Input(Emission) should be a 2-D tensor."); "The Input(Emission) should be a 2-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ( if (ctx->IsRuntime()) {
ctx, emission_dims[1], transition_dims[1], PADDLE_ENFORCE_EQ(emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) " "The 2nd dimension of the Input(Emission) and the "
"Input(Transition) "
"should be equal to the tag number."); "should be equal to the tag number.");
}
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The Input(Label) should be a 2-D tensor with the 2nd " "The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1."); "dimensions fixed to 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ( if (ctx->IsRuntime()) {
ctx, emission_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) " "The height of Input(Emission) and the height of Input(Label) "
"should be the same."); "should be the same.");
} }
}
ctx->SetOutputDim("Alpha", emission_dims); ctx->SetOutputDim("Alpha", emission_dims);
ctx->SetOutputDim("EmissionExps", emission_dims); ctx->SetOutputDim("EmissionExps", emission_dims);
ctx->SetOutputDim("TransitionExps", transition_dims); ctx->SetOutputDim("TransitionExps", transition_dims);
......
...@@ -45,19 +45,20 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -45,19 +45,20 @@ class AccuracyOp : public framework::OperatorWithKernel {
"ShapeError: label's dimensions of AccuracyOp must be 2. " "ShapeError: label's dimensions of AccuracyOp must be 2. "
"But received label's dimensions = %d, label's shape = [%s]", "But received label's dimensions = %d, label's shape = [%s]",
label_dim.size(), label_dim); label_dim.size(), label_dim);
PADDLE_INFERSHAPE_ENFORCE_EQ( if (ctx->IsRuntime()) {
ctx, label_dim[1], 1, PADDLE_ENFORCE_EQ(label_dim[1], 1,
"ShapeError: label's second dimension of " "ShapeError: label's second dimension of "
"AccuracyOp must be 1. But received label's " "AccuracyOp must be 1. But received label's "
"second dimension is = %d, label's shape = [%s]", "second dimension is = %d, label's shape = [%s]",
label_dim[1], label_dim); label_dim[1], label_dim);
PADDLE_INFERSHAPE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx, inference_dim[0], label_dim[0], inference_dim[0], label_dim[0],
"ShapeError: the output's num_rows of AccuracyOp must be" "ShapeError: the output's num_rows of AccuracyOp must be"
" the same as label's num_rows. But received output's " " the same as label's num_rows. But received output's "
"shape = [%s], label's shape = [%s], output's num_rows = %d, label's " "shape = [%s], label's shape = [%s], output's num_rows = %d, label's "
"num_rows = %d", "num_rows = %d",
inference_dim, label_dim, inference_dim[0], label_dim[0]); inference_dim, label_dim, inference_dim[0], label_dim[0]);
}
ctx->SetOutputDim("Accuracy", {1}); ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1}); ctx->SetOutputDim("Correct", {1});
......
...@@ -28,14 +28,18 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -28,14 +28,18 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null."); "Input of Label should not be null.");
auto predict_width = ctx->GetInputDim("Predict")[1]; auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_INFERSHAPE_ENFORCE_LE(ctx, predict_width, 2, if (ctx->IsRuntime()) {
PADDLE_ENFORCE_LE(predict_width, 2,
"Only support binary classification," "Only support binary classification,"
"prediction dims[1] should be 1 or 2"); "prediction dims[1] should be 1 or 2");
}
auto predict_height = ctx->GetInputDim("Predict")[0]; auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0]; auto label_height = ctx->GetInputDim("Label")[0];
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_height, label_height, if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height."); "Out and Label should have same height.");
}
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1; int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps"); int slide_steps = ctx->Attrs().Get<int>("slide_steps");
......
...@@ -135,11 +135,13 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { ...@@ -135,11 +135,13 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(out_dims.size(), 2, PADDLE_ENFORCE_GE(out_dims.size(), 2,
"The tensor rank of Input(Out@Grad) should be 2."); "The tensor rank of Input(Out@Grad) should be 2.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], in_dims[0], if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be " "The 1st dimension of Input(Out@Grad) must be "
"same as input."); "same as input.");
PADDLE_INFERSHAPE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(out_dims[1], 1,
ctx, out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1."); "The 2nd dimension of Input(Out@Grad) must be 1.");
}
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
......
...@@ -137,12 +137,14 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { ...@@ -137,12 +137,14 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], x_dims[0], if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
"First dimension of output gradient and " "First dimension of output gradient and "
"input value must be equal."); "input value must be equal.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[1], 1, PADDLE_ENFORCE_EQ(out_dims[1], 1,
"Second dimension of output gradient " "Second dimension of output gradient "
"must be 1."); "must be 1.");
}
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims); if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
......
...@@ -484,46 +484,5 @@ struct BinaryCompareMessageConverter<false> { ...@@ -484,46 +484,5 @@ struct BinaryCompareMessageConverter<false> {
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
#define __PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL1, __VAL2, __CMP, \
__INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
if (!__CTX->IsRuntime()) { \
if (__val1 == -1 || __val2 == -1) { \
break; \
} \
} \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
PADDLE_THROW("Expected %s " #__CMP " %s, but received %s:%s " #__INV_CMP \
" %s:%s.\n%s", \
#__VAL1, #__VAL2, #__VAL1, \
::paddle::string::to_string(__val1), #__VAL2, \
::paddle::string::to_string(__val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define PADDLE_INFERSHAPE_ENFORCE_EQ(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_NE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <=, >, __VA_ARGS__)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册