Unverified commit 30d9589a, authored by Zhang Ting, committed by GitHub

add cast cuda kernel (#29352)

Parent c1a26e2a
@@ -14,6 +14,39 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

// Elementwise cast: each thread converts the elements it owns from InT to
// OutT. CUDA_KERNEL_LOOP expands to a grid-stride loop, so the kernel is
// correct for any grid/block configuration.
template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}

// Specialization of CastOpFunctor for the CUDA device context: launches
// CastCUDAKernel over all elements of the input tensor.
template <typename InT>
struct CastOpFunctor<platform::CUDADeviceContext, InT> {
  const framework::Tensor* in_;
  framework::Tensor* out_;
  const platform::CUDADeviceContext& ctx_;
  CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
                const platform::CUDADeviceContext& ctx)
      : in_(in), out_(out), ctx_(ctx) {}

  template <typename OutT>
  void apply() const {
    auto* in = in_->data<InT>();
    auto size = in_->numel();
    auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
    // Derive a 1-D launch configuration from the element count, then launch
    // the kernel on the device context's stream.
    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx_, size);
    CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
                                0, ctx_.stream()>>>(in, size, out);
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
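
For context (not part of the commit): CUDA_KERNEL_LOOP is Paddle's macro for a grid-stride loop. The standalone sketch below shows the same elementwise cast written with the loop spelled out and a hand-rolled 1-D launch configuration; the names CastKernelDemo and main, and the fixed block size of 256, are illustrative assumptions, not the values GetGpuLaunchConfig1D actually computes from the device properties.

    // Minimal standalone sketch of the grid-stride cast pattern (assumed
    // names; not Paddle code). Compile with: nvcc -o cast_demo cast_demo.cu
    #include <cstdio>
    #include <cuda_runtime.h>

    template <typename InT, typename OutT>
    __global__ void CastKernelDemo(const InT* in, int64_t n, OutT* out) {
      // Each thread starts at its global index and advances by the total
      // number of launched threads, so any grid size covers all n elements.
      for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
        out[i] = static_cast<OutT>(in[i]);
      }
    }

    int main() {
      const int64_t n = 1 << 20;
      float* in;
      double* out;
      cudaMallocManaged(&in, n * sizeof(float));
      cudaMallocManaged(&out, n * sizeof(double));
      for (int64_t i = 0; i < n; ++i) in[i] = 0.5f * i;

      // A simple hand-picked 1-D launch configuration (assumption); Paddle's
      // GetGpuLaunchConfig1D chooses these values per device instead.
      const int threads = 256;
      const int blocks = static_cast<int>((n + threads - 1) / threads);
      CastKernelDemo<float, double><<<blocks, threads>>>(in, n, out);
      cudaDeviceSynchronize();

      printf("out[42] = %f\n", out[42]);  // expect 21.0
      cudaFree(in);
      cudaFree(out);
      return 0;
    }

Because the loop strides rather than assuming one thread per element, the kernel stays correct even if the launch configuration caps the grid size below n / threads, which is why the commit can delegate the block/grid choice entirely to GetGpuLaunchConfig1D.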