Commit 3a5a0c9a authored by chengtbf, committed by GitHub

softmax kernel test (#203)

* softmax kernel test

* fix some bugs in blas dot

* change implementation of softmax sub

* fix bug in softmax kernel

* static_cast

* fix
Parent 97134844
......@@ -62,7 +62,7 @@ class KernelTestCommon<DeviceType::kCPU, FloatingPointType> final {
size_t dptr_size = lhs->shape().elem_cnt();
for (size_t i = 0; i < dptr_size; ++i) {
-     ASSERT_FLOAT_EQ(dptr_lhs[i], dptr_rhs[i]);
+     ASSERT_NEAR(dptr_lhs[i], dptr_rhs[i], 0.0000001);
}
}
......
......@@ -100,9 +100,9 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
}
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
- const FloatingPointType alpha) {
+ const FloatingPointType* alpha_ptr) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
-   for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / alpha; }
+   for (int64_t i = 0; i < n; ++i) { x[i] = x[i] / (*alpha_ptr); }
});
}
......
......@@ -14,8 +14,8 @@ __global__ void ExpGpu(const int64_t n, const FloatingPointType* x,
template<typename FloatingPointType>
__global__ void DivGpu(const int64_t n, FloatingPointType* x,
- const FloatingPointType alpha) {
-   CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / alpha; }
+ const FloatingPointType* alpha_ptr) {
+   CUDA_1D_KERNEL_LOOP(i, n) { x[i] = x[i] / (*alpha_ptr); }
}
template<typename FloatingPointType>
......@@ -64,10 +64,10 @@ class KernelUtil<DeviceType::kGPU, FloatingPointType> final {
}
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
- const FloatingPointType alpha) {
+ const FloatingPointType* alpha_ptr) {
DivGpu<FloatingPointType>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
-          ctx.device_ctx->cuda_stream()>>>(n, x, alpha);
+          ctx.device_ctx->cuda_stream()>>>(n, x, alpha_ptr);
}
static void Mul(const KernelCtx& ctx, const int64_t n,
......
......@@ -67,7 +67,7 @@ class KernelUtil final {
// x = x / a
static void Div(const KernelCtx& ctx, const int64_t n, FloatingPointType* x,
- const FloatingPointType alpha);
+ const FloatingPointType* alpha_ptr);
// element-wise multiplication
// z[i] = x[i] * y[i]
......
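The Div hunks above change the divisor from a by-value scalar to a pointer: in the softmax kernel the divisor is the per-row sum stored in the tmp blob, which may live in device memory, so the kernel can now dereference it on the device instead of requiring the value to be copied back to the host first. A minimal standalone sketch of that pattern, with hypothetical names (not OneFlow code):

#include <cstdint>

// Sketch: divide n elements by a scalar that resides in device memory.
__global__ void ScalarDivSketch(const int64_t n, float* x, const float* d_divisor) {
  // each thread reads the divisor directly from device memory
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    x[i] /= *d_divisor;
  }
}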
......@@ -13,7 +13,7 @@ void SoftmaxKernel<device_type, FloatingPointType>::Forward(
const int64_t n = out_blob->shape().At(0);
const int64_t w = out_blob->shape().At(1);
const FloatingPointType* in =
-     static_cast<const FloatingPointType*>(in_blob->mut_dptr());
+     static_cast<const FloatingPointType*>(in_blob->dptr());
FloatingPointType* tmp =
static_cast<FloatingPointType*>(tmp_blob->mut_dptr());
FloatingPointType* out =
......@@ -26,10 +26,7 @@ void SoftmaxKernel<device_type, FloatingPointType>::Forward(
SoftmaxKernelUtil<device_type, FloatingPointType>::ForwardMax(ctx, n, w, out,
tmp);
// sub | every element of the out blob subtracts the max value of its sample
- for (int64_t i = 0; i < w; ++i) {
-   KernelUtil<device_type, FloatingPointType>::BlasAxpy(ctx, n, -1.0, tmp, 1,
-                                                        out + i, w);
- }
+ SoftmaxKernelUtil<device_type, FloatingPointType>::Sub(ctx, n, w, out, tmp);
// exp | exponentiation every element
KernelUtil<device_type, FloatingPointType>::Exp(ctx, n * w, out, out);
// sum | calculate sum of every sample vector out[i], store in tmp[i]
......@@ -39,7 +36,7 @@ void SoftmaxKernel<device_type, FloatingPointType>::Forward(
// div | every element of out[i] divided by the data of tmp[i] (the sum value)
for (int64_t i = 0; i < n; ++i) {
KernelUtil<device_type, FloatingPointType>::Div(ctx, w, out + i * w,
-     tmp[i]);
+     tmp + i);
}
}
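Taken together, the Forward hunks compute the standard numerically-stable softmax per row: subtract the row maximum, exponentiate, then divide by the row sum, i.e. out[i][j] = exp(in[i][j] - max_k in[i][k]) / sum_k exp(in[i][k] - max_k in[i][k]). A single-threaded CPU sketch of the same computation, for reference only (hypothetical helper, not OneFlow code):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch: row-wise numerically stable softmax over an n x w row-major matrix.
void NaiveSoftmax(int64_t n, int64_t w, const float* in, float* out) {
  for (int64_t i = 0; i < n; ++i) {
    const float* row_in = in + i * w;
    float* row_out = out + i * w;
    float max_val = *std::max_element(row_in, row_in + w);  // row max
    float sum = 0.f;
    for (int64_t j = 0; j < w; ++j) {
      row_out[j] = std::exp(row_in[j] - max_val);  // shift before exp for stability
      sum += row_out[j];
    }
    for (int64_t j = 0; j < w; ++j) { row_out[j] /= sum; }  // normalize
  }
}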
......@@ -58,22 +55,18 @@ void SoftmaxKernel<device_type, FloatingPointType>::Backward(
FloatingPointType* tmp =
static_cast<FloatingPointType*>(tmp_blob->mut_dptr());
const FloatingPointType* out =
-     static_cast<const FloatingPointType*>(out_blob->mut_dptr());
+     static_cast<const FloatingPointType*>(out_blob->dptr());
const FloatingPointType* out_diff =
-     static_cast<const FloatingPointType*>(out_diff_blob->mut_dptr());
+     static_cast<const FloatingPointType*>(out_diff_blob->dptr());
// copy out_diff to in_diff
KernelUtil<device_type, FloatingPointType>::BlasCopy(ctx, n * w, out_diff, 1,
in_diff, 1);
// dot product | get dot product tmp[i] from out[i] * out_diff[i]
- for (int64_t i = 0; i < n; ++i) {
-   KernelUtil<device_type, FloatingPointType>::BlasDot(
-       ctx, w, out + i * w, 1, out_diff + i * w, 1, tmp + i);
- }
+ SoftmaxKernelUtil<device_type, FloatingPointType>::BackwardDot(ctx, n, w, out,
+                                                                out_diff, tmp);
// sub | in_diff[i][j] -= tmp[i]
- for (int64_t i = 0; i < w; ++i) {
-   KernelUtil<device_type, FloatingPointType>::BlasAxpy(ctx, n, -1.0, tmp, 1,
-                                                        in_diff + i, w);
- }
+ SoftmaxKernelUtil<device_type, FloatingPointType>::Sub(ctx, n, w, in_diff,
+                                                        tmp);
// elementwise multiplication | in_diff[i][j] *= out[i][j]
KernelUtil<device_type, FloatingPointType>::Mul(ctx, n * w, in_diff, out,
in_diff);
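The Backward hunk implements the softmax gradient in_diff[i][j] = out[i][j] * (out_diff[i][j] - dot(out[i], out_diff[i])): copy out_diff, subtract the per-row dot product, then multiply elementwise by out. A plain CPU sketch of that formula (hypothetical helper, not OneFlow code):

#include <cstdint>

// Sketch: softmax backward for an n x w row-major matrix.
void NaiveSoftmaxBackward(int64_t n, int64_t w, const float* out,
                          const float* out_diff, float* in_diff) {
  for (int64_t i = 0; i < n; ++i) {
    float dot = 0.f;  // dot(out[i], out_diff[i])
    for (int64_t j = 0; j < w; ++j) {
      dot += out[i * w + j] * out_diff[i * w + j];
    }
    for (int64_t j = 0; j < w; ++j) {
      in_diff[i * w + j] = out[i * w + j] * (out_diff[i * w + j] - dot);
    }
  }
}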
......@@ -100,6 +93,25 @@ class SoftmaxKernelUtil<DeviceType::kCPU, FloatingPointType> final {
tmp + i);
}
}
+ static void Sub(const KernelCtx& ctx, const int64_t n, const int64_t w,
+                 FloatingPointType* matrix, const FloatingPointType* vector) {
+   for (int64_t i = 0; i < w; ++i) {
+     KernelUtil<DeviceType::kCPU, FloatingPointType>::BlasAxpy(
+         ctx, n, static_cast<FloatingPointType>(-1.0), vector, 1, matrix + i,
+         w);
+   }
+ }
+
+ static void BackwardDot(const KernelCtx& ctx, const int64_t n,
+                         const int64_t w, const FloatingPointType* out,
+                         const FloatingPointType* out_diff,
+                         FloatingPointType* tmp) {
+   for (int64_t i = 0; i < n; ++i) {
+     KernelUtil<DeviceType::kCPU, FloatingPointType>::BlasDot(
+         ctx, w, out + i * w, 1, out_diff + i * w, 1, tmp + i);
+   }
+ }
};
INSTANTIATE_CPU_KERNEL_UTIL_CLASS(SoftmaxKernelUtil);
......
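In the CPU Sub above, each BlasAxpy call views column i of the row-major n x w matrix as a strided vector (increment w) and adds -1 * vector to it; looping i over the w columns therefore subtracts vector[row] from every element of that row. A plain-loop equivalent of what the strided axpy calls compute (illustrative only):

#include <cstdint>

// Sketch: effect of SoftmaxKernelUtil<kCPU>::Sub on a row-major n x w matrix.
void NaiveSub(int64_t n, int64_t w, float* matrix, const float* vector) {
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t j = 0; j < w; ++j) { matrix[i * w + j] -= vector[i]; }
  }
}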
......@@ -12,7 +12,7 @@ __global__ void SoftmaxForwardMaxGpu(const int64_t n, const int64_t w,
CUDA_1D_KERNEL_LOOP(i, n) {
FloatingPointType max_value = out[i * w];
for (int64_t j = 0; j < w; ++j) {
-       max_value = max(max_value, out[i * w + j]);
+       max_value = max_value > out[i * w + j] ? max_value : out[i * w + j];
}
tmp[i] = max_value;
}
......@@ -29,6 +29,27 @@ __global__ void SoftmaxForwardSumGpu(const int64_t n, const int64_t w,
}
}
+ template<typename FloatingPointType>
+ __global__ void SoftmaxSubGpu(const int64_t n, const int64_t w,
+                               FloatingPointType* matrix,
+                               const FloatingPointType* vector) {
+   CUDA_1D_KERNEL_LOOP(i, n * w) { matrix[i] -= vector[i / w]; }
+ }
+
+ template<typename FloatingPointType>
+ __global__ void SoftmaxBackwardDotGpu(const int64_t n, const int64_t w,
+                                       const FloatingPointType* out,
+                                       const FloatingPointType* out_diff,
+                                       FloatingPointType* tmp) {
+   CUDA_1D_KERNEL_LOOP(i, n) {
+     FloatingPointType dot_result = 0;
+     for (int64_t j = 0; j < w; ++j) {
+       dot_result += out[i * w + j] * out_diff[i * w + j];
+     }
+     tmp[i] = dot_result;
+   }
+ }
} // namespace
template<typename FloatingPointType>
......@@ -50,6 +71,22 @@ class SoftmaxKernelUtil<DeviceType::kGPU, FloatingPointType> final {
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(n, w, out, tmp);
}
+ static void Sub(const KernelCtx& ctx, const int64_t n, const int64_t w,
+                 FloatingPointType* matrix, const FloatingPointType* vector) {
+   SoftmaxSubGpu<FloatingPointType>
+       <<<BlocksNum4ThreadsNum(n * w), kCudaThreadsNumPerBlock, 0,
+          ctx.device_ctx->cuda_stream()>>>(n, w, matrix, vector);
+ }
+
+ static void BackwardDot(const KernelCtx& ctx, const int64_t n,
+                         const int64_t w, const FloatingPointType* out,
+                         const FloatingPointType* out_diff,
+                         FloatingPointType* tmp) {
+   SoftmaxBackwardDotGpu<FloatingPointType>
+       <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
+          ctx.device_ctx->cuda_stream()>>>(n, w, out, out_diff, tmp);
+ }
};
INSTANTIATE_GPU_KERNEL_UTIL_CLASS(SoftmaxKernelUtil);
......
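On the GPU side, SoftmaxSubGpu launches one thread per matrix element (n * w in total) and recovers the row index as i / w, so vector[i / w] broadcasts the per-row value across its row; SoftmaxBackwardDotGpu instead launches one thread per row and accumulates the w-term dot product serially within that thread.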
......@@ -31,6 +31,16 @@ class SoftmaxKernelUtil final {
static void ForwardSum(const KernelCtx& ctx, const int64_t n, const int64_t w,
const FloatingPointType* out, FloatingPointType* tmp);
+ // matrix[i][j] -= vector[i]
+ // matrix shape = n*w, vector shape = n
+ static void Sub(const KernelCtx& ctx, const int64_t n, const int64_t w,
+                 FloatingPointType* matrix, const FloatingPointType* vector);
+
+ static void BackwardDot(const KernelCtx& ctx, const int64_t n,
+                         const int64_t w, const FloatingPointType* out,
+                         const FloatingPointType* out_diff,
+                         FloatingPointType* tmp);
};
} // namespace oneflow
......
#include "oneflow/core/kernel/softmax_kernel.h"
#include "oneflow/core/device/cpu_device_context.h"
#include "oneflow/core/device/cuda_device_context.h"
#include "oneflow/core/kernel/kernel_test_common.h"
namespace oneflow {
namespace test {
namespace {
template<DeviceType device_type, typename FloatingPointType>
std::function<Blob*(const std::string&)> BuildBnInOp2BlobPtr() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
FloatingPointType in_mat[8] = {1, 2, 3, 4, 0, 0, 0, 0};
FloatingPointType out_diff_mat[8] = {0.2, 1, 2, 3, -4.0, 3.0, -2.0, 1.0};
FloatingPointType expected_out_mat[8] = {
0.0320586, 0.0871443, 0.2368828, 0.6439143, 0.25, 0.25, 0.25, 0.25};
FloatingPointType expected_in_diff_mat[8] = {
-0.0737048, -0.1306350, -0.1182198, 0.3225595,
-0.875, 0.875, -0.375, 0.375};
auto bn2blob_ptr = new HashMap<std::string, Blob*>;
(*bn2blob_ptr)["in"] = KTCommon::CreateBlobWithVector({2, 4}, in_mat);
(*bn2blob_ptr)["out"] = KTCommon::CreateBlobWithSameValue({2, 4}, 0.0);
(*bn2blob_ptr)["tmp"] = KTCommon::CreateBlobWithSameValue({2}, 0.0);
(*bn2blob_ptr)["in_diff"] = KTCommon::CreateBlobWithSameValue({2, 4}, 0.0);
(*bn2blob_ptr)["out_diff"] =
KTCommon::CreateBlobWithVector({2, 4}, out_diff_mat);
(*bn2blob_ptr)["expected_out"] =
KTCommon::CreateBlobWithVector({2, 4}, expected_out_mat);
(*bn2blob_ptr)["expected_in_diff"] =
KTCommon::CreateBlobWithVector({2, 4}, expected_in_diff_mat);
return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); };
}
template<DeviceType device_type, typename FloatingPointType>
Kernel* BuildSoftmaxKernel() {
OperatorConf op_conf;
op_conf.set_name("softmax_op_test");
SoftmaxOpConf* softmax_conf = op_conf.mutable_softmax_conf();
softmax_conf->set_in("softmax/in");
softmax_conf->set_out("softmax/out");
auto softmax_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
softmax_op->ToProto(&op_proto);
auto softmax_kernel = new SoftmaxKernel<device_type, FloatingPointType>();
softmax_kernel->InitFromOpProto(op_proto);
return softmax_kernel;
}
template<DeviceType device_type, typename FloatingPointType>
void TestSoftmaxKernel() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
KernelCtx ctx;
KTCommon::BuildKernelCtx(&ctx);
auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr<device_type, FloatingPointType>();
auto softmax_kernel = BuildSoftmaxKernel<device_type, FloatingPointType>();
softmax_kernel->Forward(ctx, BnInOp2BlobPtr);
softmax_kernel->Backward(ctx, BnInOp2BlobPtr);
KTCommon::SyncStream(&ctx);
KTCommon::CheckResult(BnInOp2BlobPtr, "out", "expected_out");
KTCommon::CheckResult(BnInOp2BlobPtr, "in_diff", "expected_in_diff");
}
} // namespace
} // namespace test
TEST(SoftmaxKernel, softmax_kernel_cpu) {
test::TestSoftmaxKernel<DeviceType::kCPU, float>();
test::TestSoftmaxKernel<DeviceType::kCPU, double>();
}
TEST(SoftmaxKernel, softmax_kernel_gpu) {
test::TestSoftmaxKernel<DeviceType::kGPU, float>();
test::TestSoftmaxKernel<DeviceType::kGPU, double>();
}
} // namespace oneflow
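As a sanity check on the hard-coded expectations in the test: softmax([1, 2, 3, 4]) ≈ {0.0320586, 0.0871443, 0.2368828, 0.6439143}, softmax([0, 0, 0, 0]) = {0.25, 0.25, 0.25, 0.25}, and pushing those outputs together with out_diff_mat through the backward formula reproduces expected_in_diff_mat. A hypothetical driver using the naive sketches from earlier (not part of the test):

#include <cstdio>

// Assumes the NaiveSoftmax / NaiveSoftmaxBackward sketches defined above.
int main() {
  float in[8] = {1, 2, 3, 4, 0, 0, 0, 0};
  float out_diff[8] = {0.2f, 1, 2, 3, -4, 3, -2, 1};
  float out[8];
  float in_diff[8];
  NaiveSoftmax(2, 4, in, out);
  NaiveSoftmaxBackward(2, 4, out, out_diff, in_diff);
  for (int i = 0; i < 8; ++i) { std::printf("%g %g\n", out[i], in_diff[i]); }
  return 0;
}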