Unverified commit dab49205, authored by Zhang Ting, committed via GitHub

improve performance of cast op (#28727)

Parent commit: d12aa495
...@@ -15,11 +15,14 @@ limitations under the License. */ ...@@ -15,11 +15,14 @@ limitations under the License. */
// cast_op.cu — GPU kernel registration for the `cast` operator.
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;

// One CUDA kernel is registered per supported *input* element type; the
// output element type is chosen at run time from the `out_dtype` attribute.
REGISTER_OP_CUDA_KERNEL(
    cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>);
...@@ -48,17 +48,41 @@ struct CastOpFunctor { ...@@ -48,17 +48,41 @@ struct CastOpFunctor {
} }
}; };
// Copies the elements of input tensor `X` into output tensor `Out`,
// converting each element from InT to OutT with an Eigen element-wise cast
// executed on the device associated with DeviceContext.
template <typename DeviceContext, typename InT, typename OutT>
static void CastFunction(const framework::ExecutionContext& context) {
  const auto* input = context.Input<framework::Tensor>("X");
  auto* output = context.Output<framework::Tensor>("Out");

  auto src = framework::EigenVector<InT>::Flatten(*input);
  // Allocate the destination buffer with the target element type before
  // creating its Eigen view.
  output->mutable_data<OutT>(context.GetPlace());
  auto dst = framework::EigenVector<OutT>::Flatten(*output);

  auto& eigen_place =
      *context.template device_context<DeviceContext>().eigen_device();
  dst.device(eigen_place) = src.template cast<OutT>();
}
template <typename DeviceContext, typename InT> template <typename DeviceContext, typename InT>
class CastOpKernel : public framework::OpKernel<InT> { class CastOpKernel : public framework::OpKernel<InT> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto out_type = static_cast<framework::proto::VarType::Type>(
auto* out = context.Output<framework::Tensor>("Out"); context.Attr<int>("out_dtype"));
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>( if (out_type == paddle::framework::proto::VarType::FP64) {
context.Attr<int>("out_dtype")), CastFunction<DeviceContext, InT, double>(context);
CastOpFunctor<DeviceContext, InT>( } else if (out_type == paddle::framework::proto::VarType::FP32) {
in, out, context.template device_context<DeviceContext>())); CastFunction<DeviceContext, InT, float>(context);
} else if (out_type == paddle::framework::proto::VarType::FP16) {
CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
} else if (out_type == paddle::framework::proto::VarType::INT64) {
CastFunction<DeviceContext, InT, int64_t>(context);
} else if (out_type == paddle::framework::proto::VarType::INT32) {
CastFunction<DeviceContext, InT, int>(context);
} else if (out_type == paddle::framework::proto::VarType::UINT8) {
CastFunction<DeviceContext, InT, uint8_t>(context);
} else if (out_type == paddle::framework::proto::VarType::BOOL) {
CastFunction<DeviceContext, InT, bool>(context);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册