Unverified commit dab49205, authored by Zhang Ting, committed by GitHub

improve performance of cast op (#28727)

Parent d12aa495
paddle/fluid/operators/cast_op.cu
@@ -15,11 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
template <typename T>
using CastOpKernel =
paddle::operators::CastOpKernel<paddle::platform::CUDADeviceContext, T>;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>,
CastOpKernel<bool>, CastOpKernel<uint8_t>,
CastOpKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
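Both registration forms shown in this hunk instantiate the same CastOpKernel for CUDADeviceContext; the shorter form simply binds the device context once through a template alias so the per-type list stays readable. A minimal standalone sketch of that alias pattern follows (not Paddle code; lib, Kernel, GPUContext, and GpuKernel are hypothetical names used only for illustration):

#include <iostream>

namespace lib {
// Stand-in for an op kernel that is parameterized on a device context and an
// element type, like paddle::operators::CastOpKernel.
template <typename Context, typename T>
struct Kernel {
  void Run() const { std::cout << "run kernel for one element type\n"; }
};
struct GPUContext {};
}  // namespace lib

// Hypothetical alias mirroring `using CastOpKernel = ...` in the diff: the
// device context is fixed once, so each use names only the element type.
template <typename T>
using GpuKernel = lib::Kernel<lib::GPUContext, T>;

int main() {
  GpuKernel<float>{}.Run();   // same type as lib::Kernel<lib::GPUContext, float>
  GpuKernel<double>{}.Run();  // same type as lib::Kernel<lib::GPUContext, double>
  return 0;
}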
paddle/fluid/operators/cast_op.h
@@ -48,17 +48,41 @@ struct CastOpFunctor {
}
};
template <typename DeviceContext, typename InT, typename OutT>
static void CastFunction(const framework::ExecutionContext& context) {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto in_t = framework::EigenVector<InT>::Flatten(*in);
out->mutable_data<OutT>(context.GetPlace());
auto out_t = framework::EigenVector<OutT>::Flatten(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_t.device(place) = in_t.template cast<OutT>();
}
template <typename DeviceContext, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
auto out_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype"));
if (out_type == paddle::framework::proto::VarType::FP64) {
CastFunction<DeviceContext, InT, double>(context);
} else if (out_type == paddle::framework::proto::VarType::FP32) {
CastFunction<DeviceContext, InT, float>(context);
} else if (out_type == paddle::framework::proto::VarType::FP16) {
CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
} else if (out_type == paddle::framework::proto::VarType::INT64) {
CastFunction<DeviceContext, InT, int64_t>(context);
} else if (out_type == paddle::framework::proto::VarType::INT32) {
CastFunction<DeviceContext, InT, int>(context);
} else if (out_type == paddle::framework::proto::VarType::UINT8) {
CastFunction<DeviceContext, InT, uint8_t>(context);
} else if (out_type == paddle::framework::proto::VarType::BOOL) {
CastFunction<DeviceContext, InT, bool>(context);
}
}
};
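In the updated cast_op.h, CastOpKernel::Compute dispatches directly on the requested out_dtype with an if/else chain that calls a templated CastFunction, in place of the previous framework::VisitDataType call that wrapped a CastOpFunctor. A minimal standalone sketch of that dispatch shape follows (not Paddle code; DType, CastTo, and Cast are hypothetical names, and the example assumes non-empty input):

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for framework::proto::VarType::Type.
enum class DType { FP32, FP64, INT32, INT64 };

// Mirrors CastFunction<DeviceContext, InT, OutT>: one templated element-wise
// cast, instantiated per (input type, output type) pair.
template <typename InT, typename OutT>
std::vector<OutT> CastTo(const std::vector<InT>& in) {
  std::vector<OutT> out;
  out.reserve(in.size());
  for (InT v : in) out.push_back(static_cast<OutT>(v));
  return out;
}

// Mirrors CastOpKernel::Compute: one runtime branch per supported output type
// picks the concrete instantiation.
template <typename InT>
void Cast(const std::vector<InT>& in, DType out_dtype) {
  if (out_dtype == DType::FP64) {
    std::cout << "to double: " << CastTo<InT, double>(in).front() << "\n";
  } else if (out_dtype == DType::INT64) {
    std::cout << "to int64:  " << CastTo<InT, int64_t>(in).front() << "\n";
  } else if (out_dtype == DType::INT32) {
    std::cout << "to int32:  " << CastTo<InT, int32_t>(in).front() << "\n";
  } else {  // default to float for this sketch
    std::cout << "to float:  " << CastTo<InT, float>(in).front() << "\n";
  }
}

int main() {
  std::vector<float> x = {1.5f, 2.5f, 3.5f};
  Cast(x, DType::INT64);  // prints "to int64:  1"
  return 0;
}

Each branch instantiates the cast for one concrete output type, so selecting the output dtype at run time is a single comparison chain rather than a trip through a generic visitor.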