diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu
index f71af205766e0333017390e7a13fab9c3ccfd0e2..55cc5a675b46b7ecc6b36743f83cacf9f9ba3791 100644
--- a/paddle/fluid/operators/cast_op.cu
+++ b/paddle/fluid/operators/cast_op.cu
@@ -14,6 +14,39 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename InT, typename OutT>
+__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
+}
+
+template <typename InT>
+struct CastOpFunctor<platform::CUDADeviceContext, InT> {
+  const framework::Tensor* in_;
+  framework::Tensor* out_;
+  const platform::CUDADeviceContext& ctx_;
+  CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
+                const platform::CUDADeviceContext& ctx)
+      : in_(in), out_(out), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* in = in_->data<InT>();
+    auto size = in_->numel();
+    auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx_, size);
+    CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
+                                0, ctx_.stream()>>>(in, size, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;