未验证 提交 8ac02279 编写于 作者: Y Yu Yang 提交者: GitHub

Fix the proformance problem of enforce (#6085)

* Fix Proformance problem of enforce

* Fix missing `;` in code

* Fix CI
上级 3a8311f8
......@@ -25,7 +25,7 @@ class ConcatOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.")
"Inputs(X) of ConcatOp should be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null.");
......@@ -45,7 +45,7 @@ class ConcatOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.")
"elements except the specify axis.");
}
}
ctx->SetOutputDim("Out", out_dims);
......
......@@ -35,7 +35,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -120,7 +120,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
......
......@@ -106,7 +106,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor f;
......
......@@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same")
"The size of input-sequence and length-array should be the same");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(offset->dims()[0]),
"The size of input-sequence and offset-array should be the same")
"The size of input-sequence and offset-array should be the same");
const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
......@@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i)
"The offset[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i)
"The length[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1], "The target tensor's length overflow.")
lod[0][i + 1], "The target tensor's length overflow.");
}
out->mutable_data<T>(ctx.GetPlace());
......
......@@ -84,7 +84,7 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0;
for (int i = 0; i < N; i++) {
PADDLE_ENFORCE_EQ(out->height(),
in_vars[i]->Get<SelectedRows>().height())
in_vars[i]->Get<SelectedRows>().height());
functor(context.device_context(), in_vars[i]->Get<SelectedRows>(),
offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel();
......
......@@ -234,16 +234,24 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__));
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::string::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
do { \
if (UNLIKELY(nullptr == (__VAL))) { \
PADDLE_THROW(#__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__)); \
} \
} while (0)
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
do { \
if (!UNLIKELY((__VAL0)__CMP(__VAL1))) { \
PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \
" %s\n%s", \
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
paddle::string::to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__)); \
} \
} while (0)
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册