未验证 提交 3dd71f33 编写于 作者: S Shijie 提交者: GitHub

Accelerate scalar math cuda kernel (#6599)

* fix typo

* use cuda elementwise
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 09c3a781
......@@ -19,16 +19,25 @@ limitations under the License.
namespace oneflow {
// CUDA kernel entry point for the scalar math operation selected by BIN_OP.
// Forwards every argument unchanged to DoScalarMath<BIN_OP, T>, which is defined
// outside this excerpt.
// NOTE(review): presumably DoScalarMath combines each of the elem_cnt elements of
// `in` with `scalar` through BIN_OP and stores the result in `out` — confirm
// against its definition.
template<template<typename> class BIN_OP, typename T>
__global__ void DoCUDAScalarMath(const int64_t elem_cnt, const T scalar, const T* in, T* out) {
DoScalarMath<BIN_OP, T>(elem_cnt, scalar, in, out);
}
// Turns a two-operand operation Op<T> into a one-argument functor by fixing the
// second operand to a scalar captured at construction time.
//
// Template parameters:
//   Op - class template that must provide a static Op<T>::Invoke(T, T).
//   T  - type of the input element, the captured scalar and the result.
//
// The template template parameter is spelled with `class` rather than `typename`
// so the declaration is valid before C++17 and matches the BIN_OP declarations
// used elsewhere in this file.
template<template<typename> class Op, typename T>
struct UnaryByScalarFunctor {
  // Captures the scalar operand; callable from host and device code.
  __host__ __device__ explicit UnaryByScalarFunctor(T scalar) : scalar(scalar) {}

  // Device-only call operator: returns Op<T>::Invoke(a, scalar).
  __device__ T operator()(T a) const { return Op<T>::Invoke(a, scalar); }

  // Second operand passed to every Invoke call.
  const T scalar;
};
// Specialization of UnaryByScalarFunctor chosen when the element type is float16.
// Although it is selected by float16, everything it handles is `half`: the captured
// scalar, the call argument and the result, and the operation is dispatched to
// Op<half>::Invoke.
//
// The template template parameter is spelled with `class` rather than `typename`
// so the declaration is valid before C++17 and matches the BIN_OP declarations
// used elsewhere in this file.
template<template<typename> class Op>
struct UnaryByScalarFunctor<Op, float16> {
  // Captures the scalar operand (already converted to half by the caller);
  // callable from host and device code.
  __host__ __device__ explicit UnaryByScalarFunctor(half scalar) : scalar(scalar) {}

  // Device-only call operator: returns Op<half>::Invoke(a, scalar).
  __device__ half operator()(half a) const { return Op<half>::Invoke(a, scalar); }

  // Second operand passed to every Invoke call.
  const half scalar;
};
// GPU specialization of ScalarMathFunctor for a generic element type T.
//
// operator() parameters:
//   ctx      - device context; its cuda_stream() is the stream handed to the call.
//   elem_cnt - element count passed to cuda::elementwise::Unary.
//   scalar   - fixed second operand bound into UnaryByScalarFunctor<BIN_OP, T>.
//   in       - source buffer.
//   out      - destination buffer.
//
// The operation is issued exactly once, through cuda::elementwise::Unary, and the
// returned status is checked with OF_CUDA_CHECK.
template<template<typename> class BIN_OP, typename T>
struct ScalarMathFunctor<DeviceType::kGPU, BIN_OP, T> final {
  void operator()(DeviceCtx* ctx, const int64_t elem_cnt, const T scalar, const T* in, T* out) {
    // Unary takes the destination pointer before the source pointer.
    OF_CUDA_CHECK(cuda::elementwise::Unary(UnaryByScalarFunctor<BIN_OP, T>(scalar), elem_cnt, out,
                                           in, ctx->cuda_stream()));
  }
};
......@@ -36,9 +45,9 @@ template<template<typename> class BIN_OP>
// GPU specialization of ScalarMathFunctor for float16 data.
//
// operator() parameters:
//   ctx      - device context; its cuda_stream() is the stream handed to the call.
//   elem_cnt - element count passed to cuda::elementwise::Unary.
//   scalar   - fixed second operand, converted with float16_2half before use.
//   in       - source buffer, reinterpreted as const half*.
//   out      - destination buffer, reinterpreted as half*.
//
// The operation is issued exactly once, through cuda::elementwise::Unary, and the
// returned status is checked with OF_CUDA_CHECK.
struct ScalarMathFunctor<DeviceType::kGPU, BIN_OP, float16> final {
  void operator()(DeviceCtx* ctx, const int64_t elem_cnt, float16 scalar, const float16* in,
                  float16* out) {
    // Unary takes the destination pointer before the source pointer.
    OF_CUDA_CHECK(cuda::elementwise::Unary(
        UnaryByScalarFunctor<BIN_OP, float16>(float16_2half(scalar)), elem_cnt,
        reinterpret_cast<half*>(out), reinterpret_cast<const half*>(in), ctx->cuda_stream()));
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册