未验证 提交 f3456071 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Refine transpose flatten concat error message (#23625)

* refine fusion_transpose_flatten_concat_op log
test=develop

* fix ci error
test=develop
上级 17bee1d9
...@@ -27,14 +27,20 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel { ...@@ -27,14 +27,20 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, PADDLE_ENFORCE_GE(
"Inputs(X) of ConcatOp should be empty."); ctx->Inputs("X").size(), 1UL,
PADDLE_ENFORCE(ctx->HasOutput("Out"), platform::errors::InvalidArgument(
"Output(Out) of ConcatOp should not be null."); "Inputs(X) of TransposeFlattenConcat op should not be empty."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Inputs(X) of TransposeFlattenConcat op should not be empty."));
auto ins = ctx->GetInputsDim("X"); auto ins = ctx->GetInputsDim("X");
const size_t n = ins.size(); const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0."); PADDLE_ENFORCE_GT(n, 0,
platform::errors::InvalidArgument(
"Input tensors dim size should greater than 0."));
std::vector<int> trans_axis = std::vector<int> trans_axis =
ctx->Attrs().Get<std::vector<int>>("trans_axis"); ctx->Attrs().Get<std::vector<int>>("trans_axis");
...@@ -44,9 +50,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel { ...@@ -44,9 +50,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
size_t x_rank = ins[0].size(); size_t x_rank = ins[0].size();
size_t trans_axis_size = trans_axis.size(); size_t trans_axis_size = trans_axis.size();
PADDLE_ENFORCE_EQ(x_rank, trans_axis_size, PADDLE_ENFORCE_EQ(x_rank, trans_axis_size,
platform::errors::InvalidArgument(
"The input tensor's rank(%d) " "The input tensor's rank(%d) "
"should be equal to the permutation axis's size(%d)", "should be equal to the permutation axis's size(%d)",
x_rank, trans_axis_size); x_rank, trans_axis_size));
auto dims0 = auto dims0 =
GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0])); GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0]));
...@@ -59,9 +66,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel { ...@@ -59,9 +66,10 @@ class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
out_dims[concat_axis] += dimsi[j]; out_dims[concat_axis] += dimsi[j];
} else { } else {
PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j], PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j],
platform::errors::InvalidArgument(
"After flatting, the %d-th dim should be save " "After flatting, the %d-th dim should be save "
"except the specify axis.", "except the specify axis.",
j); j));
} }
} }
} }
......
...@@ -46,9 +46,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> { ...@@ -46,9 +46,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
cudnnTensorDescriptor_t in_desc; cudnnTensorDescriptor_t in_desc;
cudnnTensorDescriptor_t out_desc; cudnnTensorDescriptor_t out_desc;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&in_desc)); platform::dynload::cudnnCreateTensorDescriptor(&in_desc),
platform::errors::External("Create cudnn tensor descriptor failed in "
"transpose_flatten_concat_fusion op."));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&out_desc)); platform::dynload::cudnnCreateTensorDescriptor(&out_desc),
platform::errors::External("Create cudnn tensor descriptor failed in "
"transpose_flatten_concat_fusion op."));
cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type; cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
...@@ -87,15 +91,24 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> { ...@@ -87,15 +91,24 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
dims_y[i] = 1; dims_y[i] = 1;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( PADDLE_ENFORCE_CUDA_SUCCESS(
in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data())); platform::dynload::cudnnSetTensorNdDescriptor(
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()),
out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data())); platform::errors::External("Create cudnn tensorNd descriptor failed "
"in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()),
platform::errors::External("Create cudnn tensorNd descriptor failed "
"in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnTransformTensor( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnTransformTensor(
handle, CudnnDataType<T>::kOne(), in_desc, handle, CudnnDataType<T>::kOne(), in_desc,
static_cast<const void*>(ins[k]->data<T>()), static_cast<const void*>(ins[k]->data<T>()),
CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata))); CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata)),
platform::errors::External("Create cudnn transform tensor failed in "
"transpose_flatten_concat op."));
if (concat_axis == 0) { if (concat_axis == 0) {
odata += osize; odata += osize;
} else { } else {
...@@ -104,9 +117,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> { ...@@ -104,9 +117,13 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
} }
} }
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(in_desc)); platform::dynload::cudnnDestroyTensorDescriptor(in_desc),
platform::errors::External(
"Destory cudnn descriptor failed in transpose_flatten_concat op."));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(out_desc)); platform::dynload::cudnnDestroyTensorDescriptor(out_desc),
platform::errors::External(
"Destory cudnn descriptor failed in transpose_flatten_concat op."));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册