From dab49205684012411a1001be7c2e1117ae80a561 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 20 Nov 2020 10:17:38 +0800 Subject: [PATCH] improve performance of cast op (#28727) --- paddle/fluid/operators/cast_op.cu | 17 ++++++++------ paddle/fluid/operators/cast_op.h | 38 +++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 657d162878c..422adfdbb50 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -15,11 +15,14 @@ limitations under the License. */ #include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/platform/float16.h" -template -using CastOpKernel = - paddle::operators::CastOpKernel; +namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel, CastOpKernel, - CastOpKernel, CastOpKernel, - CastOpKernel, CastOpKernel, - CastOpKernel); +REGISTER_OP_CUDA_KERNEL( + cast, ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel, + ops::CastOpKernel); diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index 8fa0416049f..66079243eb4 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -48,17 +48,41 @@ struct CastOpFunctor { } }; +template +static void CastFunction(const framework::ExecutionContext& context) { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + + auto in_t = framework::EigenVector::Flatten(*in); + out->mutable_data(context.GetPlace()); + auto out_t = framework::EigenVector::Flatten(*out); + auto& place = + *context.template device_context().eigen_device(); + out_t.device(place) = in_t.template cast(); +} + template class CastOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - framework::VisitDataType( - static_cast( - context.Attr("out_dtype")), - CastOpFunctor( - in, out, context.template device_context())); + auto out_type = static_cast( + context.Attr("out_dtype")); + + if (out_type == paddle::framework::proto::VarType::FP64) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::FP32) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::FP16) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::INT64) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::INT32) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::UINT8) { + CastFunction(context); + } else if (out_type == paddle::framework::proto::VarType::BOOL) { + CastFunction(context); + } } }; -- GitLab