Unverified commit 3419de53, authored by Zhang Zheng, committed by GitHub

Support different data type between input and output (#32823)

Parent: fbbc3394
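
Context for the change: the element-wise CUDA launch path previously used a single element type T for both inputs and outputs; this commit threads separate InT/OutT template parameters through it so that an op such as abs can read complex64/complex128 tensors and write real-valued results. As a rough, self-contained sketch of the idea (not code from this commit; AbsLikeFunctor and UnaryKernel are made-up names), a unary kernel carrying two element types looks like this:

#include <cstdio>
#include <cuda_runtime.h>
#include <thrust/complex.h>

// Hypothetical functor: maps one complex input element to a real output element.
struct AbsLikeFunctor {
  __device__ float operator()(thrust::complex<float> v) const {
    return thrust::abs(v);
  }
};

// Unary element-wise kernel templated on distinct input/output element types,
// mirroring the InT/OutT split this commit introduces in the launcher.
template <typename InT, typename OutT, typename Functor>
__global__ void UnaryKernel(const InT* in, OutT* out, int n, Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(in[i]);
}

int main() {
  const int n = 8;
  thrust::complex<float> h_in[n];
  float h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = thrust::complex<float>(3.f, 4.f);

  thrust::complex<float>* d_in;
  float* d_out;
  cudaMalloc(&d_in, n * sizeof(*d_in));
  cudaMalloc(&d_out, n * sizeof(*d_out));
  cudaMemcpy(d_in, h_in, n * sizeof(*d_in), cudaMemcpyHostToDevice);

  UnaryKernel<<<1, n>>>(d_in, d_out, n, AbsLikeFunctor{});
  cudaMemcpy(h_out, d_out, n * sizeof(*d_out), cudaMemcpyDeviceToHost);
  std::printf("|3+4i| = %f\n", h_out[0]);  // prints 5.0: complex in, float out

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Inside Paddle the same pairing is expressed as LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(...), as the abs kernel in the first hunk below shows.
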
@@ -13,44 +13,79 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/abs_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T, typename Enable = void>
+struct CudaAbsFunctor;
+
+template <typename T>
+struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
+  __device__ __forceinline__ math::Real<T> operator()(const T* args) const {
+    return abs(args[0]);
+  }
+};
+
+template <typename T>
+struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return std::abs(args[0]);
+  }
+};
+
+template <typename T>
+class AbsKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    out->mutable_data<math::Real<T>>(context.GetPlace());
+
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> ins = {x};
+    std::vector<framework::Tensor*> outs = {out};
+    auto functor = CudaAbsFunctor<T>();
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(
+        dev_ctx, ins, &outs, functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
-    abs, ops::AbsKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::float16>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex64>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex128>);
+    abs, ops::AbsKernel<plat::CUDADeviceContext, float>,
+    ops::AbsKernel<plat::CUDADeviceContext, double>,
+    ops::AbsKernel<plat::CUDADeviceContext, int>,
+    ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad, ops::AbsGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex64>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex128>);
+    abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad_grad,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::float16>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex64>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex128>);
+    abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
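
A note on the functor selection above: math::Complex<T, math::Real<T>> and math::NoComplex<T, math::Real<T>> act as SFINAE guards, so exactly one CudaAbsFunctor partial specialization is viable for a given T; the complex path returns math::Real<T> while the real path returns T, which is what lets the kernel allocate its output with mutable_data<math::Real<T>>. A standard-library-only sketch of the same dispatch pattern (hypothetical names, not Paddle code):

#include <cstdio>
#include <cmath>
#include <complex>
#include <type_traits>

// Trait mapping an element type to its underlying real type.
template <typename T>
struct RealOf { using type = T; };
template <typename T>
struct RealOf<std::complex<T>> { using type = T; };

template <typename T, typename Enable = void>
struct AbsFunctorSketch;

// Enabled only for complex T: the result type is the underlying real type.
template <typename T>
struct AbsFunctorSketch<
    T, std::enable_if_t<!std::is_same<T, typename RealOf<T>::type>::value>> {
  typename RealOf<T>::type operator()(const T* args) const {
    return std::abs(args[0]);
  }
};

// Enabled only for real T: the result type is T itself.
template <typename T>
struct AbsFunctorSketch<
    T, std::enable_if_t<std::is_same<T, typename RealOf<T>::type>::value>> {
  T operator()(const T* args) const { return std::abs(args[0]); }
};

int main() {
  std::complex<double> c[] = {{3.0, 4.0}};
  double r[] = {-2.5};
  // Complex input yields a real result; real input keeps its own type.
  double abs_c = AbsFunctorSketch<std::complex<double>>()(c);  // 5.0
  double abs_r = AbsFunctorSketch<double>()(r);                // 2.5
  std::printf("%f %f\n", abs_c, abs_r);
  return 0;
}
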
@@ -1315,8 +1315,8 @@ class ActivationCudaKernel
     for (auto& attr : attrs) {
       *attr.second = ctx.Attr<float>(attr.first);
     }
-    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins, &outs,
-                                                            functor);
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
+                                                               &outs, functor);
   }
 };
 
@@ -1345,17 +1345,17 @@ class ActivationGradCudaKernel
     if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
       // Only need forward output Out
       ins.push_back(out);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else if (static_cast<int>(Functor::FwdDeps()) ==
                static_cast<int>(kDepX)) {
       // Only need forward input X
       ins.push_back(x);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else {
-      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins,
-                                                              &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, functor);
     }
   }
 };
...
@@ -45,7 +45,7 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
                   framework::Tensor* z) {
     std::vector<const framework::Tensor*> ins = {x, y};
     std::vector<framework::Tensor*> outs = {z};
-    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
         ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
         CudaAddFunctor<T>());
   }
...
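
CudaAddFunctor itself is not part of this diff; for same-type ops such as element-wise add, the launcher is now simply instantiated with the element type twice (kBinary, T, T). A hypothetical functor in the launcher's args-array calling convention (the convention visible in CudaAbsFunctor above) would look roughly like this:

#include <cstdio>

// Hypothetical sketch of a binary functor in the launcher's args-array calling
// convention (args[0], args[1] are the two inputs). __host__ __device__ is used
// here only so the sketch can be exercised from host code below; Paddle's
// actual CudaAddFunctor is not shown in this diff.
template <typename T>
struct AddFunctorSketch {
  __host__ __device__ T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

int main() {
  float args[2] = {1.5f, 2.25f};
  // Same-type binary op: InT == OutT == float, i.e. the <kBinary, T, T> case.
  std::printf("%f\n", AddFunctorSketch<float>()(args));  // 3.75
  return 0;
}
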
@@ -49,69 +49,73 @@ int GetVectorizedSizeImpl(const T *pointer) {
   return 1;
 }
 
-template <typename T>
+template <typename InT, typename OutT>
 int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                       const std::vector<framework::Tensor *> &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>()));
   }
   return vec_size;
 }
 
-template <ElementwiseType ET, int VecSize, typename T>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  T *out;
-  const T *in0;
-  const T *in1;
-  __device__ ElementwiseDataWrapper(T *out, const T *in0,
-                                    const T *in1 = nullptr)
+  OutT *out;
+  const InT *in0;
+  const InT *in1;
+  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
+                                    const InT *in1 = nullptr)
       : out(out), in0(in0), in1(in1) {}
 
-  using VecType = CudaAlignedVector<T, VecSize>;
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
 
-  inline __device__ void load_vector(VecType args[], int idx) {
-    const VecType *x_vec = reinterpret_cast<const VecType *>(in0);
+  inline __device__ void load_vector(InVecType args[], int idx) {
+    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
     args[0] = x_vec[idx];
     if (ET == ElementwiseType::kBinary) {
-      const VecType *y_vec = reinterpret_cast<const VecType *>(in1);
+      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
       args[1] = y_vec[idx];
     }
   }
 
-  inline __device__ void load_scalar(T args[], int idx) {
+  inline __device__ void load_scalar(InT args[], int idx) {
     args[0] = in0[idx];
     if (ET == ElementwiseType::kBinary) {
      args[1] = in1[idx];
    }
  }
 
-  inline __device__ void store_vector(VecType res, int idx) {
-    VecType *out_vec = reinterpret_cast<VecType *>(out);
+  inline __device__ void store_vector(OutVecType res, int idx) {
+    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
     out_vec[idx] = res;
   }
 
-  inline __device__ void store_scalar(T res, int idx) { out[idx] = res; }
+  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
 };
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
-  using VecType = CudaAlignedVector<T, VecSize>;
-  VecType ins_vec[ET];
-  VecType out_vec;
-  T *ins_ptr[ET];
-  T *out_ptr;
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int tid) {
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+  InVecType ins_vec[ET];
+  OutVecType out_vec;
+  InT *ins_ptr[ET];
+  OutT *out_ptr;
 #pragma unroll
   for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i]));
+    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
-  out_ptr = reinterpret_cast<T *>(&out_vec);
+  out_ptr = reinterpret_cast<OutT *>(&out_vec);
   // load
   data.load_vector(ins_vec, tid);
 
@@ -119,7 +123,7 @@ __device__ void VectorizedKernelImpl(
   // compute
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    T ins[ET];
+    InT ins[ET];
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
       ins[j] = ins_ptr[j][i];
@@ -131,11 +135,13 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
-                                 Functor func, int start, int remain) {
-  T ins[ET];
-  T out;
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__device__ void ScalarKernelImpl(
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int start, int remain) {
+  InT ins[ET];
+  OutT out;
 
   for (int i = 0; i < remain; ++i) {
     int idx = start + i;
@@ -148,14 +154,15 @@ __device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
   }
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__global__ void VectorizedKernel(const T *__restrict__ in0,
-                                 const T *__restrict__ in1, T *out, int size,
-                                 Functor func) {
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void VectorizedKernel(const InT *__restrict__ in0,
+                                 const InT *__restrict__ in1, OutT *out,
+                                 int size, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
   if (remain >= VecSize) {
     VectorizedKernelImpl(data, func, tid);
   } else {
@@ -163,30 +170,31 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   }
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__global__ void ScalarKernel(const T *__restrict__ in0,
-                             const T *__restrict__ in1, T *out, int size,
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+__global__ void ScalarKernel(const InT *__restrict__ in0,
+                             const InT *__restrict__ in1, OutT *out, int size,
                              Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
   ScalarKernelImpl(data, func, tid, remain);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<T>(ins, *outs);
+  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
   int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const T *in0 = ins[0]->data<T>();
-  const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr;
-  T *out = (*outs)[0]->data<T>();
+  const InT *in0 = ins[0]->data<InT>();
+  const InT *in1 =
+      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
+  OutT *out = (*outs)[0]->data<OutT>();
   // cuda kernel
   auto stream = ctx.stream();
   switch (vec_size) {
...
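
One consequence of the InT/OutT split shows up in GetVectorizedSize above: the input and output tensors may not admit the same vector width (their element sizes and pointer alignments can differ, for example complex64 in, float out), so the launcher takes the minimum across both lists before choosing a kernel. Below is a self-contained sketch in the same spirit of a vectorized pass whose load and store element types differ; AlignedVector, VectorizedConvert, and DemoteFunctor are stand-in names, not Paddle code (AlignedVector only approximates CudaAlignedVector):

#include <cuda_runtime.h>

// Stand-in for CudaAlignedVector: VecSize elements loaded/stored as one unit.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
  T val[VecSize];
};

// Vectorized unary pass with separate input/output element types: each thread
// loads one InT vector, converts element by element, and stores one OutT
// vector, mirroring load_vector/store_vector in ElementwiseDataWrapper.
template <typename InT, typename OutT, int VecSize, typename Functor>
__global__ void VectorizedConvert(const InT* in, OutT* out, int n,
                                  Functor func) {
  using InVec = AlignedVector<InT, VecSize>;
  using OutVec = AlignedVector<OutT, VecSize>;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int start = VecSize * tid;
  if (start + VecSize <= n) {
    // Full vector: one aligned load of InT, one aligned store of OutT.
    InVec in_vec = reinterpret_cast<const InVec*>(in)[tid];
    OutVec out_vec;
#pragma unroll
    for (int i = 0; i < VecSize; ++i) out_vec.val[i] = func(in_vec.val[i]);
    reinterpret_cast<OutVec*>(out)[tid] = out_vec;
  } else {
    // Tail: plain scalar accesses, like ScalarKernelImpl.
    for (int idx = start; idx < n; ++idx) out[idx] = func(in[idx]);
  }
}

// Example functor: double-precision input, single-precision output.
struct DemoteFunctor {
  __device__ float operator()(double v) const {
    return static_cast<float>(v) * 2.0f;
  }
};

int main() {
  const int n = 10;  // not a multiple of VecSize, so the tail path runs too
  double h_in[n];
  float h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = 0.5 * i;

  double* d_in;
  float* d_out;
  cudaMalloc(&d_in, n * sizeof(double));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(double), cudaMemcpyHostToDevice);

  constexpr int kVecSize = 4;
  int threads = (n + kVecSize - 1) / kVecSize;  // 3 threads cover 10 elements
  VectorizedConvert<double, float, kVecSize><<<1, threads>>>(
      d_in, d_out, n, DemoteFunctor{});
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  // h_out[i] == float(i): double in, float out (0.5 * i * 2)
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
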