未验证 提交 f3456071 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Refine transpose flatten concat error message (#23625)

* refine fusion_transpose_flatten_concat_op log
test=develop

* fix ci error
test=develop
上级 17bee1d9
......@@ -27,14 +27,20 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null.");
PADDLE_ENFORCE_GE(
ctx->Inputs("X").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(X) of TransposeFlattenConcat op should not be empty."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Inputs(X) of TransposeFlattenConcat op should not be empty."));
auto ins = ctx->GetInputsDim("X");
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
PADDLE_ENFORCE_GT(n, 0,
platform::errors::InvalidArgument(
"Input tensors dim size should greater than 0."));
std::vector<int> trans_axis =
ctx->Attrs().Get<std::vector<int>>("trans_axis");
......@@ -44,9 +50,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
size_t x_rank = ins[0].size();
size_t trans_axis_size = trans_axis.size();
PADDLE_ENFORCE_EQ(x_rank, trans_axis_size,
"The input tensor's rank(%d) "
"should be equal to the permutation axis's size(%d)",
x_rank, trans_axis_size);
platform::errors::InvalidArgument(
"The input tensor's rank(%d) "
"should be equal to the permutation axis's size(%d)",
x_rank, trans_axis_size));
auto dims0 =
GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0]));
......@@ -59,9 +66,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
out_dims[concat_axis] += dimsi[j];
} else {
PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j],
"After flatting, the %d-th dim should be save "
"except the specify axis.",
j);
platform::errors::InvalidArgument(
"After flatting, the %d-th dim should be save "
"except the specify axis.",
j));
}
}
}
......
......@@ -46,9 +46,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
cudnnTensorDescriptor_t in_desc;
cudnnTensorDescriptor_t out_desc;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&in_desc));
platform::dynload::cudnnCreateTensorDescriptor(&in_desc),
platform::errors::External("Create cudnn tensor descriptor failed in "
"transpose_flatten_concat_fusion op."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&out_desc));
platform::dynload::cudnnCreateTensorDescriptor(&out_desc),
platform::errors::External("Create cudnn tensor descriptor failed in "
"transpose_flatten_concat_fusion op."));
cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
......@@ -87,15 +91,24 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
dims_y[i] = 1;
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnTransformTensor(
handle, CudnnDataType<T>::kOne(), in_desc,
static_cast<const void*>(ins[k]->data<T>()),
CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata)));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()),
platform::errors::External("Create cudnn tensorNd descriptor failed "
"in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()),
platform::errors::External("Create cudnn tensorNd descriptor failed "
"in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnTransformTensor(
handle, CudnnDataType<T>::kOne(), in_desc,
static_cast<const void*>(ins[k]->data<T>()),
CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata)),
platform::errors::External("Create cudnn transform tensor failed in "
"transpose_flatten_concat op."));
if (concat_axis == 0) {
odata += osize;
} else {
......@@ -104,9 +117,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
}
}
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(in_desc));
platform::dynload::cudnnDestroyTensorDescriptor(in_desc),
platform::errors::External(
"Destory cudnn descriptor failed in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(out_desc));
platform::dynload::cudnnDestroyTensorDescriptor(out_desc),
platform::errors::External(
"Destory cudnn descriptor failed in transpose_flatten_concat op."));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册