提交 e7f176d9 编写于 作者: C chengtbf 提交者: Jingwu Chen

softmax kernel and some functions for kernel util (#194)

* softmax kernel and some functions for kernel util

* fix bug when using KernelUtil::BlasAxpy
上级 9628cb31
......@@ -76,6 +76,44 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
[n, alpha, x, incx]() { cblas_scal(n, alpha, x, incx); });
}
static void Max(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* max_ptr) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
*max_ptr = x[0];
for (int64_t i = 0; i < n; ++i) { *max_ptr = std::max(*max_ptr, x[i]); }
});
}
static void Exp(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* y) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
for (int64_t i = 0; i < n; ++i) { y[i] = std::exp(x[i]); }
});
}
static void Sum(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* sum_ptr) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
*sum_ptr = 0;
for (int64_t i = 0; i < n; ++i) { *sum_ptr += x[i]; }
});
}
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
const FloatingPointType alpha) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / alpha; }
});
}
static void Mul(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, const FloatingPointType* y,
FloatingPointType* z) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[i]; }
});
}
static void BlasGemv(const KernelCtx& ctx, const enum CBLAS_TRANSPOSE trans,
int m, int n, const FloatingPointType alpha,
const FloatingPointType* a, int lda,
......
......@@ -4,6 +4,28 @@
namespace oneflow {
namespace {
template<typename FloatingPointType>
__global__ void ExpGpu(const int64_t n, const FloatingPointType* x,
FloatingPointType* y) {
CUDA_1D_KERNEL_LOOP(i, n) { y[i] = std::exp(x[i]); }
}
template<typename FloatingPointType>
__global__ void DivGpu(const int64_t n, FloatingPointType* x,
const FloatingPointType alpha) {
CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / alpha; }
}
template<typename FloatingPointType>
__global__ void MulGpu(const int64_t n, const FloatingPointType* x,
const FloatingPointType* y, FloatingPointType* z) {
CUDA_1D_KERNEL_LOOP(i, n) { z[i] = x[i] * y[i]; }
}
} // namespace
template<typename FloatingPointType>
class KernelUtil<DeviceType::kGPU, FloatingPointType> final {
public:
......@@ -34,6 +56,28 @@ class KernelUtil<DeviceType::kGPU, FloatingPointType> final {
cublas_scal(ctx.device_ctx->cublas_handle(), n, &alpha, x, incx);
}
static void Exp(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* y) {
ExpGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, x, y);
}
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
const FloatingPointType alpha) {
DivGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, x, alpha);
}
static void Mul(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, const FloatingPointType* y,
FloatingPointType* z) {
MulGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, x, y, z);
}
static void BlasGemv(const KernelCtx& ctx, const enum CBLAS_TRANSPOSE trans,
int m, int n, const FloatingPointType alpha,
const FloatingPointType* a, int lda,
......
......@@ -51,6 +51,29 @@ class KernelUtil final {
static void BlasScal(const KernelCtx& ctx, const int n,
const FloatingPointType alpha, FloatingPointType* x,
const int incx);
// max(x)
// NO template specialization for GPU
static void Max(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* max_ptr);
// y = exp(x)
static void Exp(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* y);
// sum(x)
// NO template specialization for GPU
static void Sum(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, FloatingPointType* sum_ptr);
// x = x / a
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
const FloatingPointType alpha);
// element-wise multiplication
// z[i] = x[i] * y[i]
static void Mul(const KernelCtx& ctx, const int64_t n,
const FloatingPointType* x, const FloatingPointType* y,
FloatingPointType* z);
// level 2 matrix and vector
// matrix vector multiply
......
#include "oneflow/core/kernel/softmax_kernel.h"
#include "oneflow/core/kernel/kernel_manager.h"
namespace oneflow {
template<DeviceType device_type, typename FloatingPointType>
void SoftmaxKernel<device_type, FloatingPointType>::Forward(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2BlobPtr) const {
Blob* in_blob = BnInOp2BlobPtr(op()->SoleIbn());
Blob* out_blob = BnInOp2BlobPtr(op()->SoleObn());
Blob* tmp_blob = BnInOp2BlobPtr(op()->SoleDtbn());
const int64_t n = out_blob->shape().At(0);
const int64_t w = out_blob->shape().At(1);
const FloatingPointType* in =
static_cast<const FloatingPointType*>(in_blob->mut_dptr());
FloatingPointType* tmp =
static_cast<FloatingPointType*>(tmp_blob->mut_dptr());
FloatingPointType* out =
static_cast<FloatingPointType*>(out_blob->mut_dptr());
// copy in blob to out blob
KernelUtil<device_type, FloatingPointType>::BlasCopy(ctx, n * w, in, 1, out,
1);
// max | calculate max of every sample vector out[i], store in tmp[i]
// the out[i] now is store the data of in[i]
SoftmaxKernelUtil<device_type, FloatingPointType>::ForwardMax(ctx, n, w, out,
tmp);
// sub | every element of out blob subract the max value of the same sample
for (int64_t i = 0; i < w; ++i) {
KernelUtil<device_type, FloatingPointType>::BlasAxpy(ctx, n, -1.0, tmp, 1,
out + i, w);
}
// exp | exponentiation every element
KernelUtil<device_type, FloatingPointType>::Exp(ctx, n * w, out, out);
// sum | calculate sum of every sample vector out[i], store in tmp[i]
// the out[i] now is store the tmp data after exp
SoftmaxKernelUtil<device_type, FloatingPointType>::ForwardSum(ctx, n, w, out,
tmp);
// div | every element of out[i] divided by the data of tmp[i] (the sum value)
for (int64_t i = 0; i < n; ++i) {
KernelUtil<device_type, FloatingPointType>::Div(ctx, w, out + i * w,
tmp[i]);
}
}
template<DeviceType device_type, typename FloatingPointType>
void SoftmaxKernel<device_type, FloatingPointType>::Backward(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2BlobPtr) const {
Blob* out_blob = BnInOp2BlobPtr(op()->SoleObn());
Blob* out_diff_blob = BnInOp2BlobPtr(op()->SoleOdbn());
Blob* in_diff_blob = BnInOp2BlobPtr(op()->SoleIdbn());
Blob* tmp_blob = BnInOp2BlobPtr(op()->SoleDtbn());
const int64_t n = out_blob->shape().At(0);
const int64_t w = out_blob->shape().At(1);
FloatingPointType* in_diff =
static_cast<FloatingPointType*>(in_diff_blob->mut_dptr());
FloatingPointType* tmp =
static_cast<FloatingPointType*>(tmp_blob->mut_dptr());
const FloatingPointType* out =
static_cast<const FloatingPointType*>(out_blob->mut_dptr());
const FloatingPointType* out_diff =
static_cast<const FloatingPointType*>(out_diff_blob->mut_dptr());
// copy out_diff to in_diff
KernelUtil<device_type, FloatingPointType>::BlasCopy(ctx, n * w, out_diff, 1,
in_diff, 1);
// dot product | get dot product tmp[i] from out[i] * out_diff[i]
for (int64_t i = 0; i < n; ++i) {
KernelUtil<device_type, FloatingPointType>::BlasDot(
ctx, w, out + i * w, 1, out_diff + i * w, 1, tmp + i);
}
// sub | in_diff[i][j] -= tmp[i]
for (int64_t i = 0; i < w; ++i) {
KernelUtil<device_type, FloatingPointType>::BlasAxpy(ctx, n, -1.0, tmp, 1,
in_diff + i, w);
}
// elementwise multiplication | in_diff[i][j] *= out[i][j]
KernelUtil<device_type, FloatingPointType>::Mul(ctx, n * w, in_diff, out,
in_diff);
}
template<typename FloatingPointType>
class SoftmaxKernelUtil<DeviceType::kCPU, FloatingPointType> final {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxKernelUtil);
SoftmaxKernelUtil() = delete;
static void ForwardMax(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp) {
for (int64_t i = 0; i < n; ++i) {
KernelUtil<DeviceType::kCPU, FloatingPointType>::Max(ctx, w, out + i * w,
tmp + i);
}
}
static void ForwardSum(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp) {
for (int64_t i = 0; i < n; ++i) {
KernelUtil<DeviceType::kCPU, FloatingPointType>::Sum(ctx, w, out + i * w,
tmp + i);
}
}
};
INSTANTIATE_CPU_KERNEL_UTIL_CLASS(SoftmaxKernelUtil);
INSTANTIATE_KERNEL_CLASS(SoftmaxKernel);
REGISTER_KERNEL(OperatorConf::kSoftmaxConf, SoftmaxKernel);
} // namespace oneflow
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/kernel/softmax_kernel.h"
namespace oneflow {
namespace {
template<typename FloatingPointType>
__global__ void SoftmaxForwardMaxGpu(const int64_t n, const int64_t w,
const FloatingPointType* out,
FloatingPointType* tmp) {
CUDA_1D_KERNEL_LOOP(i, n) {
FloatingPointType max_value = out[i * w];
for (int64_t j = 0; j < w; ++j) {
max_value = max(max_value, out[i * w + j]);
}
tmp[i] = max_value;
}
}
template<typename FloatingPointType>
__global__ void SoftmaxForwardSumGpu(const int64_t n, const int64_t w,
const FloatingPointType* out,
FloatingPointType* tmp) {
CUDA_1D_KERNEL_LOOP(i, n) {
FloatingPointType sum_value = 0;
for (int64_t j = 0; j < w; ++j) { sum_value += out[i * w + j]; }
tmp[i] = sum_value;
}
}
} // namespace
template<typename FloatingPointType>
class SoftmaxKernelUtil<DeviceType::kGPU, FloatingPointType> final {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxKernelUtil);
SoftmaxKernelUtil() = delete;
static void ForwardMax(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp) {
SoftmaxForwardMaxGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, w, out, tmp);
}
static void ForwardSum(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp) {
SoftmaxForwardSumGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, w, out, tmp);
}
};
INSTANTIATE_GPU_KERNEL_UTIL_CLASS(SoftmaxKernelUtil);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_SOFTMAX_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_SOFTMAX_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename FloatingPointType>
class SoftmaxKernel final : public Kernel {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxKernel);
SoftmaxKernel() = default;
~SoftmaxKernel() = default;
void Forward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void Backward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
template<DeviceType device_type, typename FloatingPointType>
class SoftmaxKernelUtil final {
public:
OF_DISALLOW_COPY_AND_MOVE(SoftmaxKernelUtil);
SoftmaxKernelUtil() = delete;
// n = number of data sample
// w = number of (input/output) neuron
static void ForwardMax(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp);
static void ForwardSum(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_SOFTMAX_KERNEL_H_
......@@ -8,7 +8,7 @@ void SoftmaxOp::InitFromOpConf(const OperatorConf& op_conf) {
EnrollInputBn("in");
EnrollOutputBn("out");
EnrollDataTmpBn("tmp_max");
EnrollDataTmpBn("tmp");
}
const PbMessage& SoftmaxOp::GetSpecialConf() const {
......
......@@ -20,9 +20,9 @@ TEST(SoftmaxOp, softmax_3x5) {
softmax_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = fp(softmax_op->SoleObn());
Shape* tmp_max_shape_ptr = fp(softmax_op->SoleDtbn());
Shape* tmp_shape_ptr = fp(softmax_op->SoleDtbn());
ASSERT_EQ(*output_shape_ptr, Shape({3, 5}));
ASSERT_EQ(*tmp_max_shape_ptr, Shape({3}));
ASSERT_EQ(*tmp_shape_ptr, Shape({3}));
}
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册