Commit 5d777454 authored by Jinhui Yuan, committed by GitHub

speedup clone kernel (#1017)

* speedup clone kernel

* refine

* simplify

* handle an arbitrary number of out_diff


Former-commit-id: 9d0bae73
Parent 1248e065
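The diff below speeds up the backward pass of the clone kernel: instead of one Axpy-style pass per out_diff blob, in_diff is zeroed once and the out_diff blobs are accumulated through fused AdditionAssign overloads that take up to ten inputs at a time, with an arbitrary blob count handled as full batches of ten plus a remainder. A minimal standalone C++ sketch of that dispatch idea follows; the helper names AdditionAssignBatch and AccumulateAll are illustrative and do not appear in the diff.

#include <cstdint>
#include <vector>

// Editor's sketch (assumed helper names, not code from this commit): accumulate
// any number of gradient buffers into `out` in fused batches of at most ten,
// so `out` is read and written once per batch instead of once per buffer.
static void AdditionAssignBatch(int64_t n, float* out, const float* const* ins,
                                int32_t num_ins) {
  for (int64_t i = 0; i < n; ++i) {
    float sum = 0.f;
    for (int32_t k = 0; k < num_ins; ++k) { sum += ins[k][i]; }
    out[i] += sum;
  }
}

static void AccumulateAll(int64_t n, float* out, const std::vector<const float*>& ins) {
  int32_t offset = 0;
  const int32_t total = static_cast<int32_t>(ins.size());
  // Full batches of ten, mirroring the while loop in the clone kernel below.
  while (total - offset >= 10) {
    AdditionAssignBatch(n, out, ins.data() + offset, 10);
    offset += 10;
  }
  // Remainder of 1..9 buffers, mirroring the switch statement below.
  if (total - offset > 0) { AdditionAssignBatch(n, out, ins.data() + offset, total - offset); }
}

The diff itself hand-writes ten fixed-arity overloads instead of the pointer-array loop above, presumably to avoid passing a pointer array to the CUDA kernels.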
@@ -15,40 +15,218 @@ void CloneKernel<device_type, T>::Forward(
template<DeviceType device_type, typename T>
struct CloneKernelUtil {
-  // b += a
-  static void AdditionAssign(DeviceCtx* device_ctx, const Blob* a, Blob* b);
// out += in
static void AdditionAssign(DeviceCtx* device_ctx, const int64_t elem_cnt, Blob* out,
const Blob* in);
};
template<DeviceType device_type, typename T>
void CloneKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const PbRpf<std::string>& odbns = this->op_attribute().output_diff_bns();
-  if (odbns.size() == 0) return;
size_t out_num = odbns.size();
if (out_num == 0) return;
Blob* in_diff_blob = BnInOp2Blob(this->op_attribute().input_diff_bns(0));
-  const Blob* out_diff_blob_0 = BnInOp2Blob(odbns[0]);
-  Memcpy<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr(), out_diff_blob_0->dptr(),
-                      out_diff_blob_0->ByteSizeOfDataContentField());
-  for (size_t i = 1; i != odbns.size(); ++i) {
-    const Blob* out_diff_blob = BnInOp2Blob(odbns[i]);
-    CloneKernelUtil<device_type, T>::AdditionAssign(ctx.device_ctx, out_diff_blob, in_diff_blob);
-  }
Memset<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr<T>(), 0,
in_diff_blob->ByteSizeOfDataContentField());
auto out_diff = [&](int32_t idx) {
return BnInOp2Blob(this->op_attribute().output_diff_bns(idx));
};
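  // Accumulate the out_diff blobs into in_diff_blob in batches of at most ten
  // per AdditionAssign call.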
int32_t offset = 0;
while (out_num - offset >= 10) {
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1), out_diff(offset + 2),
out_diff(offset + 3), out_diff(offset + 4), out_diff(offset + 5), out_diff(offset + 6),
out_diff(offset + 7), out_diff(offset + 8), out_diff(offset + 9));
offset += 10;
}
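  // Dispatch the remaining 1..9 out_diff blobs to the matching fixed-arity overload.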
if (out_num - offset > 0) {
switch (out_num - offset) {
case 1:
CloneKernelUtil<device_type, T>::AdditionAssign(ctx.device_ctx, in_diff_blob,
out_diff(offset));
break;
case 2:
CloneKernelUtil<device_type, T>::AdditionAssign(ctx.device_ctx, in_diff_blob,
out_diff(offset), out_diff(offset + 1));
break;
case 3:
CloneKernelUtil<device_type, T>::AdditionAssign(ctx.device_ctx, in_diff_blob,
out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2));
break;
case 4:
CloneKernelUtil<device_type, T>::AdditionAssign(ctx.device_ctx, in_diff_blob,
out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3));
break;
case 5:
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3), out_diff(offset + 4));
break;
case 6:
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3), out_diff(offset + 4), out_diff(offset + 5));
break;
case 7:
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3), out_diff(offset + 4), out_diff(offset + 5),
out_diff(offset + 6));
break;
case 8:
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3), out_diff(offset + 4), out_diff(offset + 5),
out_diff(offset + 6), out_diff(offset + 7));
break;
case 9:
CloneKernelUtil<device_type, T>::AdditionAssign(
ctx.device_ctx, in_diff_blob, out_diff(offset), out_diff(offset + 1),
out_diff(offset + 2), out_diff(offset + 3), out_diff(offset + 4), out_diff(offset + 5),
out_diff(offset + 6), out_diff(offset + 7), out_diff(offset + 8));
break;
}
}
}
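// For floating-point types, each fixed-arity CloneKernelUtil::AdditionAssign
// forwards to the corresponding KernelUtil::AdditionAssign overload.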
#define DEFINE_FLOATING_CLONE_KERNEL_UTIL(type_cpp, type_proto) \
template<DeviceType device_type> \
struct CloneKernelUtil<device_type, type_cpp> { \
-    static void AdditionAssign(DeviceCtx* device_ctx, const Blob* a, Blob* b) { \
-      KernelUtil<device_type, type_cpp>::Axpy(device_ctx, a->shape().elem_cnt(), 1.0, \
-                                              a->dptr<type_cpp>(), 1, b->mut_dptr<type_cpp>(), 1); \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>(), in_5->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>(), in_5->dptr<type_cpp>(), in_6->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>(), in_5->dptr<type_cpp>(), in_6->dptr<type_cpp>(), \
in_7->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7, const Blob* in_8) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>(), in_5->dptr<type_cpp>(), in_6->dptr<type_cpp>(), \
in_7->dptr<type_cpp>(), in_8->dptr<type_cpp>()); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7, const Blob* in_8, const Blob* in_9) { \
KernelUtil<device_type, type_cpp>::AdditionAssign( \
device_ctx, out->shape().elem_cnt(), out->mut_dptr<type_cpp>(), in_0->dptr<type_cpp>(), \
in_1->dptr<type_cpp>(), in_2->dptr<type_cpp>(), in_3->dptr<type_cpp>(), \
in_4->dptr<type_cpp>(), in_5->dptr<type_cpp>(), in_6->dptr<type_cpp>(), \
in_7->dptr<type_cpp>(), in_8->dptr<type_cpp>(), in_9->dptr<type_cpp>()); \
} \
};
OF_PP_FOR_EACH_TUPLE(DEFINE_FLOATING_CLONE_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
-#define DEFINE_NONFLOAT_CLONE_KERNEL_UTIL(type_cpp, type_proto) \
-  template<DeviceType device_type> \
-  struct CloneKernelUtil<device_type, type_cpp> { \
-    static void AdditionAssign(DeviceCtx* device_ctx, const Blob* a, Blob* b) { UNIMPLEMENTED(); } \
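// For integral and char types the same overloads exist only to satisfy the
// interface; each one is UNIMPLEMENTED().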
#define DEFINE_NONFLOAT_CLONE_KERNEL_UTIL(type_cpp, type_proto) \
template<DeviceType device_type> \
struct CloneKernelUtil<device_type, type_cpp> { \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7, const Blob* in_8) { \
UNIMPLEMENTED(); \
} \
static void AdditionAssign(DeviceCtx* device_ctx, Blob* out, const Blob* in_0, \
const Blob* in_1, const Blob* in_2, const Blob* in_3, \
const Blob* in_4, const Blob* in_5, const Blob* in_6, \
const Blob* in_7, const Blob* in_8, const Blob* in_9) { \
UNIMPLEMENTED(); \
} \
};
OF_PP_FOR_EACH_TUPLE(DEFINE_NONFLOAT_CLONE_KERNEL_UTIL, INT_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ)
@@ -352,6 +352,63 @@ KU_FLOATING_METHOD ReluBackward(DeviceCtx* ctx, const int64_t n, const T* x, con
T zero = ZeroVal<T>::value;
for (int64_t i = 0; i != n; ++i) { dx[i] = (y[i] > zero) * dy[i]; }
}
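// CPU fused accumulation: each overload adds its 1..10 inputs into `out` in a
// single pass over the n elements.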
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0) {
for (int64_t i = 0; i != n; ++i) { out[i] += in_0[i]; }
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1) {
for (int64_t i = 0; i != n; ++i) { out[i] += in_0[i] + in_1[i]; }
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2) {
for (int64_t i = 0; i != n; ++i) { out[i] += in_0[i] + in_1[i] + in_2[i]; }
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3) {
for (int64_t i = 0; i != n; ++i) { out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i]; }
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4) {
for (int64_t i = 0; i != n; ++i) { out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i]; }
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5) {
for (int64_t i = 0; i != n; ++i) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i];
}
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6) {
for (int64_t i = 0; i != n; ++i) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i];
}
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7) {
for (int64_t i = 0; i != n; ++i) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i];
}
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7, const T* in_8) {
for (int64_t i = 0; i != n; ++i) {
out[i] +=
in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i] + in_8[i];
}
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7, const T* in_8,
const T* in_9) {
for (int64_t i = 0; i != n; ++i) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i]
+ in_8[i] + in_9[i];
}
}
KU_FLOATING_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
uint32_t random_seed, Blob* blob) {
@@ -65,6 +65,76 @@ __global__ void ReluBackwardGpu(const int n, const T* y, const T* dy, T* dx) {
CUDA_1D_KERNEL_LOOP(i, n) { dx[i] = y[i] > 0 ? dy[i] : 0; }
}
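// GPU fused accumulation kernels: each grid-stride loop adds its 1..10 inputs
// into `out` within a single launch.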
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i] + in_1[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i] + in_1[i] + in_2[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4, const T* in_5) {
CUDA_1D_KERNEL_LOOP(i, n) { out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i]; }
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4, const T* in_5, const T* in_6) {
CUDA_1D_KERNEL_LOOP(i, n) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i];
}
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4, const T* in_5, const T* in_6,
const T* in_7) {
CUDA_1D_KERNEL_LOOP(i, n) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i];
}
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4, const T* in_5, const T* in_6,
const T* in_7, const T* in_8) {
CUDA_1D_KERNEL_LOOP(i, n) {
out[i] +=
in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i] + in_8[i];
}
}
template<typename T>
__global__ void gpu_add_assign(const int64_t n, T* out, const T* in_0, const T* in_1, const T* in_2,
const T* in_3, const T* in_4, const T* in_5, const T* in_6,
const T* in_7, const T* in_8, const T* in_9) {
CUDA_1D_KERNEL_LOOP(i, n) {
out[i] += in_0[i] + in_1[i] + in_2[i] + in_3[i] + in_4[i] + in_5[i] + in_6[i] + in_7[i]
+ in_8[i] + in_9[i];
}
}
cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) {
cublasOperation_t cublas_trans;
if (trans == CBLAS_TRANSPOSE::CblasNoTrans) {
@@ -416,6 +486,70 @@ KU_FLOATING_METHOD ReluBackward(DeviceCtx* ctx, const int64_t n, const T* x, con
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, y, dy, dx);
}
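// Each AdditionAssign overload below launches the matching gpu_add_assign
// kernel on the stream held by `ctx`.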
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0) {
gpu_add_assign<T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, out, in_0);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4, in_5);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4, in_5, in_6);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4, in_5, in_6, in_7);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7, const T* in_8) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8);
}
KU_FLOATING_METHOD AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0,
const T* in_1, const T* in_2, const T* in_3, const T* in_4,
const T* in_5, const T* in_6, const T* in_7, const T* in_8,
const T* in_9) {
gpu_add_assign<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, out, in_0, in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9);
}
#define KU_INTEGRAL_METHOD \
template<typename T> \
void KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsIntegral<T>::value>::type>::
@@ -179,6 +179,29 @@ struct KernelUtil<DeviceType::kCPU, T, typename std::enable_if<IsFloating<T>::va
static void ReluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, const T* dy,
T* dx);
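  // Fused accumulation: out[i] += in_0[i] + ... + in_k[i], for 1 to 10 inputs.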
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7, const T* in_8);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7, const T* in_8, const T* in_9);
static void InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
uint32_t random_seed, Blob* blob);
static void InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
@@ -261,6 +284,29 @@ struct KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsFloating<T>::va
static void Relu(DeviceCtx* ctx, int64_t n, const T* x, T* y);
static void ReluBackward(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, const T* dy,
T* dx);
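  // Fused accumulation overloads mirroring the CPU declarations above.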
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7, const T* in_8);
static void AdditionAssign(DeviceCtx* ctx, const int64_t n, T* out, const T* in_0, const T* in_1,
const T* in_2, const T* in_3, const T* in_4, const T* in_5,
const T* in_6, const T* in_7, const T* in_8, const T* in_9);
};
// GPU, Integral