From 657c69bc6da889dd27892f1a86625a893cbc6977 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Wed, 31 Aug 2022 10:29:49 +0800
Subject: [PATCH] Fix fused cuda op's mutable data [3] (#45564)

---
 paddle/fluid/operators/fused/fusion_group_op.h     | 14 ++++++++++----
 .../fused/fusion_transpose_flatten_concat_op.cu.cc |  4 ++--
 paddle/fluid/operators/fused/skip_layernorm_op.cu  |  3 ++-
 paddle/fluid/operators/fused/yolo_box_head_op.cu   |  2 +-
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/fused/fusion_group_op.h b/paddle/fluid/operators/fused/fusion_group_op.h
index 3c4ccb940d..bd19a630e5 100644
--- a/paddle/fluid/operators/fused/fusion_group_op.h
+++ b/paddle/fluid/operators/fused/fusion_group_op.h
@@ -23,17 +23,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+template <typename DeviceContext>
 static void MutableMultiTypeData(
     std::vector<framework::LoDTensor*>* var,
     const std::vector<int>& data_type,
+    const DeviceContext& dev_ctx,
     const platform::Place& place) {
   for (size_t i = 0; i < var->size(); i++) {
     if (data_type[i] == framework::proto::VarType::FP32) {
-      (*var)[i]->mutable_data<float>(place);
+      dev_ctx.template Alloc<float>((*var)[i],
+                                    (*var)[i]->numel() * sizeof(float));
     } else if (data_type[i] == framework::proto::VarType::FP16) {
-      (*var)[i]->mutable_data<paddle::platform::float16>(place);
+      dev_ctx.template Alloc<paddle::platform::float16>(
+          (*var)[i], (*var)[i]->numel() * sizeof(paddle::platform::float16));
     } else if (data_type[i] == framework::proto::VarType::FP64) {
-      (*var)[i]->mutable_data<double>(place);
+      dev_ctx.template Alloc<double>((*var)[i],
+                                     (*var)[i]->numel() * sizeof(double));
     }
   }
 }
@@ -52,8 +57,9 @@ class FusionGroupKernel : public framework::OpKernel<T> {
 
     size_t num_outs = outs.size();
     auto place = ctx.GetPlace();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
 
-    MutableMultiTypeData(&outs, outs_dtype, place);
+    MutableMultiTypeData(&outs, outs_dtype, dev_ctx, place);
 
     std::string func_name = ctx.Attr<std::string>("func_name");
     platform::DeviceCode* dev_code =
diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc
index 9a1e58c632..4d063ba2be 100644
--- a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc
+++ b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc
@@ -30,7 +30,8 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
+    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
 
     auto odims = out->dims();
     std::vector<int> trans_axis = ctx.Attr<std::vector<int>>("trans_axis");
@@ -52,7 +53,6 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
         platform::dynload::cudnnCreateTensorDescriptor(&out_desc));
     cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
 
-    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
     auto handle = dev_ctx.cudnn_handle();
 
     T* odata = out->data<T>();
diff --git a/paddle/fluid/operators/fused/skip_layernorm_op.cu b/paddle/fluid/operators/fused/skip_layernorm_op.cu
index 3038898f2a..307d61b31a 100644
--- a/paddle/fluid/operators/fused/skip_layernorm_op.cu
+++ b/paddle/fluid/operators/fused/skip_layernorm_op.cu
@@ -44,7 +44,8 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
 
     auto *out = context.Output<framework::Tensor>("Out");
     out->Resize(X->dims());
-    auto *output_d = out->mutable_data<T>(context.GetPlace());
+    auto &dev_ctx = context.template device_context<phi::GPUContext>();
+    auto *output_d = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
 
     size_t num = 1;
     for (size_t i = 0; i < X->dims().size(); i++) {
diff --git a/paddle/fluid/operators/fused/yolo_box_head_op.cu b/paddle/fluid/operators/fused/yolo_box_head_op.cu
index b82b9a931a..f932b13d99 100644
--- a/paddle/fluid/operators/fused/yolo_box_head_op.cu
+++ b/paddle/fluid/operators/fused/yolo_box_head_op.cu
@@ -81,7 +81,7 @@ class YoloBoxHeadKernel : public framework::OpKernel<T> {
     const int grid_size_y = h;
     const int anchors_num = anchors.size() / 2;
     const T* input_data = x->data<T>();
-    T* output_data = out->mutable_data<T>(context.GetPlace());
+    T* output_data = device_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     auto stream = device_ctx.stream();
     const int volume = x_dims[1] * h * w;
     dim3 block(16, 16, 4);
--
GitLab
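
For readers following this cleanup series, below is a minimal sketch of the allocation idiom the patch applies, separate from the diff itself. It assumes a GPU-only kernel, so phi::GPUContext is used as the concrete device context; AllocOutputSketch, its parameters, and the included header are hypothetical illustrations, not code from this patch.

// Hypothetical sketch, not part of the patch: migrating an output
// allocation from Tensor::mutable_data to DeviceContext::Alloc.
#include "paddle/fluid/framework/op_registry.h"  // assumed header

namespace paddle {
namespace operators {

template <typename T>
void AllocOutputSketch(const framework::ExecutionContext& context,
                       framework::Tensor* out) {
  // Old idiom: the tensor allocates its own memory, given only a place.
  //   T* p = out->mutable_data<T>(context.GetPlace());

  // New idiom: the device context performs the allocation, with the byte
  // count spelled out as element count times element width.
  auto& dev_ctx = context.template device_context<phi::GPUContext>();
  T* p = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
  (void)p;  // typically passed straight to a CUDA kernel launch
}

}  // namespace operators
}  // namespace paddle

Routing the allocation through the device context hands the memory to the context's allocator rather than the legacy Tensor-owned path, which appears to be the point of this mutable_data series.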