未验证 提交 8ac02279 编写于 作者: Y Yu Yang 提交者: GitHub

Fix the proformance problem of enforce (#6085)

* Fix Proformance problem of enforce

* Fix missing `;` in code

* Fix CI
上级 3a8311f8
...@@ -25,7 +25,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -25,7 +25,7 @@ class ConcatOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.") "Inputs(X) of ConcatOp should be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null."); "Output(Out) of ConcatOp should not be null.");
...@@ -45,7 +45,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -45,7 +45,7 @@ class ConcatOp : public framework::OperatorWithKernel {
} }
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same " "Input tensors should have the same "
"elements except the specify axis.") "elements except the specify axis.");
} }
} }
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
......
...@@ -35,7 +35,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -120,7 +120,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -120,7 +120,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
......
...@@ -106,7 +106,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { ...@@ -106,7 +106,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.");
if (x_dims == y_dims) { if (x_dims == y_dims) {
functor f; functor f;
......
...@@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> { ...@@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]), n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same") "The size of input-sequence and length-array should be the same");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(offset->dims()[0]), n, static_cast<size_t>(offset->dims()[0]),
"The size of input-sequence and offset-array should be the same") "The size of input-sequence and offset-array should be the same");
const int64_t* offset_data = offset->data<int64_t>(); const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>(); const int64_t* length_data = length->data<int64_t>();
...@@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> { ...@@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i], PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i) "The offset[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(0, length_data[i], PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i) "The length[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i], PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1], "The target tensor's length overflow.") lod[0][i + 1], "The target tensor's length overflow.");
} }
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
......
...@@ -84,7 +84,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -84,7 +84,7 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
PADDLE_ENFORCE_EQ(out->height(), PADDLE_ENFORCE_EQ(out->height(),
in_vars[i]->Get<SelectedRows>().height()) in_vars[i]->Get<SelectedRows>().height());
functor(context.device_context(), in_vars[i]->Get<SelectedRows>(), functor(context.device_context(), in_vars[i]->Get<SelectedRows>(),
offset, out); offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel(); offset += in_vars[i]->Get<SelectedRows>().value().numel();
......
...@@ -234,16 +234,24 @@ inline void throw_on_error(T e) { ...@@ -234,16 +234,24 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \ #define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \ do { \
paddle::string::Sprintf("" __VA_ARGS__)); if (UNLIKELY(nullptr == (__VAL))) { \
PADDLE_THROW(#__VAL " should not be null\n%s", \
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ paddle::string::Sprintf("" __VA_ARGS__)); \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \ } \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ } while (0)
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::string::to_string(__VAL1), \ #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
paddle::string::Sprintf("" __VA_ARGS__)); do { \
if (!UNLIKELY((__VAL0)__CMP(__VAL1))) { \
PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \
" %s\n%s", \
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::string::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__)); \
} \
} while (0)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册