From f676d774b7cae0aeab43041fb74906866b6c2d4d Mon Sep 17 00:00:00 2001
From: binbinHan
Date: Tue, 4 Sep 2018 14:32:29 +0800
Subject: [PATCH] Dev hinge loss (#1190)

* add hinge loss

* add hinge loss test

* hack hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss

Former-commit-id: e2da4ecff712850e35eb3dfb926d4eb7349765a8
---
 oneflow/core/kernel/hinge_loss_kernel.cpp    | 114 ++++++++++++++++++
 oneflow/core/kernel/hinge_loss_kernel.cu     |  99 +++++++++++++++
 oneflow/core/kernel/hinge_loss_kernel.h      |  33 +++++
 .../core/kernel/hinge_loss_kernel_test.cpp   |  56 +++++++++
 oneflow/core/kernel/kernel.proto             |   5 +
 oneflow/core/operator/hinge_loss_op.cpp      |  33 +++++
 oneflow/core/operator/hinge_loss_op.h        |  25 ++++
 oneflow/core/operator/op_conf.proto          |  16 +++
 8 files changed, 381 insertions(+)
 create mode 100644 oneflow/core/kernel/hinge_loss_kernel.cpp
 create mode 100644 oneflow/core/kernel/hinge_loss_kernel.cu
 create mode 100644 oneflow/core/kernel/hinge_loss_kernel.h
 create mode 100644 oneflow/core/kernel/hinge_loss_kernel_test.cpp
 create mode 100644 oneflow/core/operator/hinge_loss_op.cpp
 create mode 100644 oneflow/core/operator/hinge_loss_op.h

diff --git a/oneflow/core/kernel/hinge_loss_kernel.cpp b/oneflow/core/kernel/hinge_loss_kernel.cpp
new file mode 100644
index 0000000000..af27c7190c
--- /dev/null
+++ b/oneflow/core/kernel/hinge_loss_kernel.cpp
#include "oneflow/core/kernel/hinge_loss_kernel.h"

namespace oneflow {

template<DeviceType device_type, typename PredType, typename LabelType>
void HingeLossKernel<device_type, PredType, LabelType>::VirtualLossForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  const Blob* prediction_blob = BnInOp2Blob("prediction");
  const Blob* label_blob = BnInOp2Blob("label");
  Blob* loss_blob = BnInOp2Blob("loss");
  Blob* tmp_diff_blob = BnInOp2Blob("tmp_diff");
  Blob* tmp_blob = BnInOp2Blob("tmp");
  Blob* tmp_storage_blob = BnInOp2Blob("tmp_storage");
  const int64_t piece_size = prediction_blob->shape().At(0);
  const int64_t pre_dim = prediction_blob->shape().Count(1);
  const OperatorConf& op_conf = this->op_conf();
  tmp_diff_blob->CopyDataContentFrom(ctx.device_ctx, prediction_blob);
  // forward
  HingeLossKernelUtil<device_type, PredType, LabelType>::Forward(
      ctx.device_ctx, piece_size, pre_dim, prediction_blob->dptr<PredType>(),
      label_blob->dptr<LabelType>(), op_conf, tmp_diff_blob->mut_dptr<PredType>(),
      tmp_blob->mut_dptr<PredType>(), tmp_storage_blob->mut_dptr<PredType>(),
      loss_blob->mut_dptr<PredType>());
  // if prediction_diff_blob is not null, then do backward
  Blob* prediction_diff_blob = BnInOp2Blob(GenDiffBn("prediction"));
  if (prediction_diff_blob != nullptr) {
    HingeLossKernelUtil<device_type, PredType, LabelType>::Backward(
        ctx.device_ctx, piece_size, pre_dim, tmp_diff_blob->mut_dptr<PredType>(),
        label_blob->dptr<LabelType>(), op_conf, prediction_diff_blob->mut_dptr<PredType>());
  }
}

template<DeviceType device_type, typename PredType, typename LabelType>
const LossKernelConf& HingeLossKernel<device_type, PredType, LabelType>::GetLossKernelConf(
    const KernelConf& kernel_conf) const {
  return kernel_conf.hinge_loss_conf().loss_conf();
}

template<typename PredType, typename LabelType>
struct HingeLossKernelUtil<DeviceType::kCPU, PredType, LabelType> {
  static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                      const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
                      PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss) {
    // transform the sign of each pred according to the label
    for (int64_t i = 0; i < piece_size; ++i) {
      tmp_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
    }
    // compute diff of each dim
    for (int64_t i = 0; i < piece_size * pre_dim; ++i) {
      tmp_diff[i] = (1 + tmp_diff[i]) > 0 ? (1 + tmp_diff[i]) : 0;
    }
    switch (op_conf.hinge_loss_conf().norm()) {
      case L1:
        KernelUtil<DeviceType::kCPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp_diff, loss);
        break;
      case L2:
        KernelUtil<DeviceType::kCPU, PredType>::Mul(ctx, piece_size * pre_dim, tmp_diff, tmp_diff,
                                                    tmp);
        KernelUtil<DeviceType::kCPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss);
        /*for (int64_t i = 0; i < piece_size; ++i) {
          KernelUtil<DeviceType::kCPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
                                                      tmp_diff + i * pre_dim, 1, loss + i);
        }*/
        break;
      default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
    }
  }

  static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                       const PredType* tmp_diff, const LabelType* label,
                       const OperatorConf& op_conf, PredType* pred_diff) {
    for (int64_t i = 0; i < piece_size * pre_dim; ++i) { pred_diff[i] = (tmp_diff[i] > 0); }
    for (int64_t i = 0; i < piece_size; ++i) {
      pred_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
    }
    switch (op_conf.hinge_loss_conf().norm()) {
      case L1: break;
      case L2:
        for (int64_t i = 0; i < piece_size * pre_dim; ++i) {
          pred_diff[i] = 2 * tmp_diff[i] * pred_diff[i];
        }
        break;
      default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
    }
  }
};

namespace {

Kernel* CreateHingeLossKernel(const KernelConf& kernel_conf) {
  static const HashMap<std::string, std::function<Kernel*()>> creators = {
#define HINGE_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, label_type_pair)                      \
  {GetHashKey(device_type, OF_PP_PAIR_SECOND(pred_type_pair), OF_PP_PAIR_SECOND(label_type_pair)), \
   []() {                                                                                          \
     return new HingeLossKernel<device_type, OF_PP_PAIR_FIRST(pred_type_pair),                     \
                                OF_PP_PAIR_FIRST(label_type_pair)>();                              \
   }},
      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(HINGE_LOSS_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
                                       FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)};
  return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(),
                                kernel_conf.hinge_loss_conf().loss_conf().prediction_type(),
                                kernel_conf.hinge_loss_conf().loss_conf().label_type()))();
}

}  // namespace

REGISTER_KERNEL_CREATOR(OperatorConf::kHingeLossConf, CreateHingeLossKernel);

#define MAKE_ENTRY(data_type_pair, label_type_pair)                                       \
  template struct HingeLossKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \
                                      OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)

}  // namespace oneflow
diff --git a/oneflow/core/kernel/hinge_loss_kernel.cu b/oneflow/core/kernel/hinge_loss_kernel.cu
new file mode 100644
index 0000000000..74cb90b113
--- /dev/null
+++ b/oneflow/core/kernel/hinge_loss_kernel.cu
#include "oneflow/core/kernel/hinge_loss_kernel.h"

namespace oneflow {

namespace {

template<typename PredType, typename LabelType>
__global__ void HingeLossForwardTransSignGpu(const int64_t piece_size, const int64_t pre_dim,
                                             const LabelType* label, PredType* tmp_diff) {
  CUDA_1D_KERNEL_LOOP(i, piece_size) {
    tmp_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
  }
}

template<typename PredType>
__global__ void HingeLossForwardMaxGpu(const int64_t piece_size, const int64_t pre_dim,
                                       PredType* tmp_diff) {
  CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) {
    tmp_diff[i] = 1 + tmp_diff[i] > 0 ? 1 + tmp_diff[i] : 0;
  }
}

template<typename PredType>
__global__ void HingeLossBackwardTransSignGpu(const int64_t piece_size, const int64_t pre_dim,
                                              const PredType* tmp_diff, PredType* pred_diff) {
  CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) { pred_diff[i] = (tmp_diff[i] > 0); }
}

template<typename PredType, typename LabelType>
__global__ void HingeLossBackwardL1Gpu(const int64_t piece_size, const int64_t pre_dim,
                                       const LabelType* label, PredType* pred_diff) {
  CUDA_1D_KERNEL_LOOP(i, piece_size) {
    pred_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
  }
}

template<typename PredType>
__global__ void HingeLossBackwardL2Gpu(const int64_t piece_size, const int64_t pre_dim,
                                       const PredType* tmp_diff, PredType* pred_diff) {
  CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) { pred_diff[i] = 2 * tmp_diff[i] * pred_diff[i]; }
}

}  // namespace

template<typename PredType, typename LabelType>
struct HingeLossKernelUtil<DeviceType::kGPU, PredType, LabelType> {
  static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                      const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
                      PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss) {
    HingeLossForwardTransSignGpu<<<BlocksNum4ThreadsNum(piece_size), kCudaThreadsNumPerBlock, 0,
                                   ctx->cuda_stream()>>>(piece_size, pre_dim, label, tmp_diff);
    HingeLossForwardMaxGpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim), kCudaThreadsNumPerBlock,
                             0, ctx->cuda_stream()>>>(piece_size, pre_dim, tmp_diff);
    switch (op_conf.hinge_loss_conf().norm()) {
      case L1:
        KernelUtil<DeviceType::kGPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp_diff, loss,
                                                       tmp_storage,
                                                       sizeof(PredType) * piece_size * pre_dim);
        break;
      case L2:
        KernelUtil<DeviceType::kGPU, PredType>::Mul(ctx, piece_size * pre_dim, tmp_diff, tmp_diff,
                                                    tmp);
        KernelUtil<DeviceType::kGPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss,
                                                       tmp_storage,
                                                       sizeof(PredType) * piece_size * pre_dim);
        /*for (int64_t i = 0; i < piece_size; ++i) {
          KernelUtil<DeviceType::kGPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
                                                      tmp_diff + i * pre_dim, 1, loss + i);
        }*/
        break;
      default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
    }
  }

  static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                       const PredType* tmp_diff, const LabelType* label,
                       const OperatorConf& op_conf, PredType* pred_diff) {
    HingeLossBackwardTransSignGpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim),
                                    kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
        piece_size, pre_dim, tmp_diff, pred_diff);
    HingeLossBackwardL1Gpu<<<BlocksNum4ThreadsNum(piece_size), kCudaThreadsNumPerBlock, 0,
                             ctx->cuda_stream()>>>(piece_size, pre_dim, label, pred_diff);
    switch (op_conf.hinge_loss_conf().norm()) {
      case L1: break;
      case L2:
        HingeLossBackwardL2Gpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim),
                                 kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
            piece_size, pre_dim, tmp_diff, pred_diff);
        break;
      default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
    }
  }
};

#define MAKE_ENTRY(data_type_pair, label_type_pair)                                       \
  template struct HingeLossKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \
                                      OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
}  // namespace oneflow
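[Reviewer note] The CPU and GPU paths above implement the same per-row arithmetic: the prediction at each sample's true label is sign-flipped, every entry is passed through max(0, 1 + x), and each row is then reduced, as a plain sum for the L1 norm or a sum of squares for L2; the backward pass is the indicator of the active hinge terms, negated at the label column and, for L2, scaled by 2 * tmp_diff. A minimal standalone C++ sketch of that arithmetic follows (reference only, not part of the patch; HingeForward/HingeBackward are our own names):

#include <algorithm>
#include <cstdio>
#include <vector>

enum Norm { L1 = 1, L2 = 2 };

// Forward: tmp_diff holds max(0, 1 + s), where s is pred with the entry at the
// true label's column sign-flipped; loss[i] reduces row i (sum for L1, sum of
// squares for L2).
void HingeForward(int piece_size, int pre_dim, const std::vector<double>& pred,
                  const std::vector<int>& label, Norm norm, std::vector<double>* tmp_diff,
                  std::vector<double>* loss) {
  *tmp_diff = pred;
  for (int i = 0; i < piece_size; ++i) { (*tmp_diff)[i * pre_dim + label[i]] *= -1; }
  for (double& v : *tmp_diff) { v = std::max(0.0, 1.0 + v); }
  loss->assign(piece_size, 0.0);
  for (int i = 0; i < piece_size; ++i) {
    for (int j = 0; j < pre_dim; ++j) {
      const double t = (*tmp_diff)[i * pre_dim + j];
      (*loss)[i] += (norm == L2) ? t * t : t;
    }
  }
}

// Backward: the gradient is the indicator of the active hinge terms, negated
// at the label column and, for the L2 norm, scaled by 2 * tmp_diff.
void HingeBackward(int piece_size, int pre_dim, const std::vector<double>& tmp_diff,
                   const std::vector<int>& label, Norm norm, std::vector<double>* pred_diff) {
  pred_diff->resize(piece_size * pre_dim);
  for (int i = 0; i < piece_size * pre_dim; ++i) {
    (*pred_diff)[i] = tmp_diff[i] > 0 ? 1.0 : 0.0;
  }
  for (int i = 0; i < piece_size; ++i) { (*pred_diff)[i * pre_dim + label[i]] *= -1; }
  if (norm == L2) {
    for (int i = 0; i < piece_size * pre_dim; ++i) { (*pred_diff)[i] *= 2 * tmp_diff[i]; }
  }
}

int main() {
  // One row of the unit test below: label 2, norm L2.
  const std::vector<double> pred = {-1.73, -1.24, 0.89, -0.99, 0.05};
  const std::vector<int> label = {2};
  std::vector<double> tmp_diff, loss, pred_diff;
  HingeForward(1, 5, pred, label, L2, &tmp_diff, &loss);
  HingeBackward(1, 5, tmp_diff, label, L2, &pred_diff);
  std::printf("loss = %.4f\n", loss[0]);                   // 1.1147, as in the test
  for (double v : pred_diff) { std::printf("%.2f ", v); }  // 0.00 0.00 -0.22 0.02 2.10
  std::printf("\n");
  return 0;
}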
diff --git a/oneflow/core/kernel/hinge_loss_kernel.h b/oneflow/core/kernel/hinge_loss_kernel.h
new file mode 100644
index 0000000000..befa52b31f
--- /dev/null
+++ b/oneflow/core/kernel/hinge_loss_kernel.h
#ifndef ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_

#include "oneflow/core/kernel/loss_kernel.h"

namespace oneflow {

template<DeviceType device_type, typename PredType, typename LabelType>
class HingeLossKernel final : public LossKernel<device_type, PredType, LabelType> {
 public:
  OF_DISALLOW_COPY_AND_MOVE(HingeLossKernel);
  HingeLossKernel() = default;
  ~HingeLossKernel() = default;

 private:
  void VirtualLossForwardDataContent(const KernelCtx&,
                                     std::function<Blob*(const std::string&)>) const override;
  const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override;
};

template<DeviceType device_type, typename PredType, typename LabelType>
struct HingeLossKernelUtil {
  static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                      const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
                      PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss);
  static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
                       const PredType* tmp_diff, const LabelType* label,
                       const OperatorConf& op_conf, PredType* pred_diff);
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_
diff --git a/oneflow/core/kernel/hinge_loss_kernel_test.cpp b/oneflow/core/kernel/hinge_loss_kernel_test.cpp
new file mode 100644
index 0000000000..7c3165d784
--- /dev/null
+++ b/oneflow/core/kernel/hinge_loss_kernel_test.cpp
#include "oneflow/core/kernel/opkernel_test_case.h"
#include "oneflow/core/common/switch_func.h"

namespace oneflow {

namespace test {

template<DeviceType device_type, typename PredType>
struct HingeLossTestUtil final {
#define HINGE_LOSS_TEST_UTIL_ENTRY(func_name, T) \
  HingeLossTestUtil<device_type, PredType>::template func_name<T>

  DEFINE_STATIC_SWITCH_FUNC(void, Test, HINGE_LOSS_TEST_UTIL_ENTRY,
                            MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));

  template<typename LabelType>
  static void Test(OpKernelTestCase* test_case, const std::string& job_type,
                   const std::string& fw_or_bw) {
    test_case->set_is_train(job_type == "train");
    test_case->set_is_forward(fw_or_bw == "forward");
    HingeLossOpConf* hinge_loss_conf = test_case->mut_op_conf()->mutable_hinge_loss_conf();
    hinge_loss_conf->set_norm(Norm::L2);
    hinge_loss_conf->set_label("test/label");
    hinge_loss_conf->set_prediction("test/prediction");
    hinge_loss_conf->set_loss("test/loss");
    BlobDesc* label_blob_desc =
        new BlobDesc(Shape({2}), GetDataType<LabelType>::value, false, false, 1);
    BlobDesc* pred_blob_desc =
        new BlobDesc(Shape({2, 5}), GetDataType<PredType>::value, false, false, 1);
    BlobDesc* loss_blob_desc =
        new BlobDesc(Shape({2}), GetDataType<PredType>::value, false, false, 1);
    test_case->InitBlob<LabelType>("label", label_blob_desc, {2, 2});
    test_case->InitBlob<PredType>(
        "prediction", pred_blob_desc,
        {-1.73, -1.24, 0.89, -0.99, 0.05, -1.73, -1.24, 0.89, -0.99, 0.05});
    test_case->ForwardCheckBlob<PredType>("loss", loss_blob_desc, {1.1147, 1.1147});
    test_case->BackwardCheckBlob<PredType>(
        GenDiffBn("prediction"), pred_blob_desc,
        {0.00, 0.00, -0.22, 0.02, 2.10, 0.00, 0.00, -0.22, 0.02, 2.10});
  }
};

template<DeviceType device_type, typename PredType>
void HingeLossKernelTestCase(OpKernelTestCase* test_case, const std::string& label_type,
                             const std::string& job_type, const std::string& fw_or_bw) {
  HingeLossTestUtil<device_type, PredType>::SwitchTest(SwitchCase(label_type), test_case, job_type,
                                                       fw_or_bw);
}

TEST_CPU_AND_GPU_OPKERNEL(HingeLossKernelTestCase, FLOATING_DATA_TYPE_SEQ,
                          OF_PP_SEQ_MAP(OF_PP_PAIR_FIRST, INT_DATA_TYPE_SEQ), (train)(predict),
                          (forward)(backward));

}  // namespace test

}  // namespace oneflow
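[Reviewer note] The expected values in this test can be checked by hand. With norm = L2 and label 2, the prediction row {-1.73, -1.24, 0.89, -0.99, 0.05} first has its third entry sign-flipped, giving {-1.73, -1.24, -0.89, -0.99, 0.05}; max(0, 1 + x) then yields tmp_diff = {0, 0, 0.11, 0.01, 1.05}, and the sum of squares is 0.0121 + 0.0001 + 1.1025 = 1.1147, the value in ForwardCheckBlob. The gradient is the active-hinge indicator {0, 0, 1, 1, 1}, negated at the label column and scaled by 2 * tmp_diff, giving {0, 0, -0.22, 0.02, 2.10}, the values in BackwardCheckBlob.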
diff --git a/oneflow/core/kernel/kernel.proto b/oneflow/core/kernel/kernel.proto
index ffa374e85d..cdeeab3ddc 100644
--- a/oneflow/core/kernel/kernel.proto
+++ b/oneflow/core/kernel/kernel.proto
@@ -137,6 +137,10 @@ message OpAttribute {
   repeated string const_buf_bns = 14;
 }
 
+message HingeLossKernelConf {
+  required LossKernelConf loss_conf = 1;
+}
+
 message KernelConf {
   required OpAttribute op_attribute = 1;
   optional bool need_do_data_id = 2 [default = false];
@@ -152,6 +156,7 @@
     SparseCrossEntropyLossKernelConf sparse_cross_entropy_loss_conf = 102;
     DecodeRandomKernelConf decode_random_conf = 103;
     DecodeOFRecordKernelConf decode_ofrecord_conf = 104;
+    HingeLossKernelConf hinge_loss_conf = 105;
     SoftmaxKernelConf softmax_conf = 116;
     TransposeKernelConf transpose_conf = 117;
     ReduceSumKernelConf reduce_sum_conf = 120;
diff --git a/oneflow/core/operator/hinge_loss_op.cpp b/oneflow/core/operator/hinge_loss_op.cpp
new file mode 100644
index 0000000000..39a6c022aa
--- /dev/null
+++ b/oneflow/core/operator/hinge_loss_op.cpp
#include "oneflow/core/operator/hinge_loss_op.h"

namespace oneflow {

void HingeLossOp::VirtualInitFromOpConf() {
  EnrollDataTmpBn("tmp");
  EnrollDataTmpBn("tmp_diff");
  EnrollDataTmpBn("tmp_storage");  // used by GPU
}

const PbMessage& HingeLossOp::GetCustomizedConf() const { return op_conf().hinge_loss_conf(); }

LossKernelConf* HingeLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const {
  return kernel_conf->mutable_hinge_loss_conf()->mutable_loss_conf();
}

void HingeLossOp::VirtualInferBlobDescs(
    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx) const {
  const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
#define OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(blobname)                    \
  BlobDesc* blobname##_blob_desc = GetBlobDesc4BnInOp(#blobname);      \
  blobname##_blob_desc->mut_shape() = Shape(pred_blob_desc->shape()); \
  blobname##_blob_desc->set_data_type(pred_blob_desc->data_type())
  OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp);
  OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp_diff);
  OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp_storage);
#undef OF_HINGE_LOSS_INFER_TMP_BLOB_DESC
}

REGISTER_OP(OperatorConf::kHingeLossConf, HingeLossOp);

}  // namespace oneflow
diff --git a/oneflow/core/operator/hinge_loss_op.h b/oneflow/core/operator/hinge_loss_op.h
new file mode 100644
index 0000000000..527754fd6c
--- /dev/null
+++ b/oneflow/core/operator/hinge_loss_op.h
#ifndef ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_
#define ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_

#include "oneflow/core/operator/loss_op.h"

namespace oneflow {

class HingeLossOp final : public LossOp {
 public:
  OF_DISALLOW_COPY_AND_MOVE(HingeLossOp);
  HingeLossOp() = default;
  ~HingeLossOp() = default;

  const PbMessage& GetCustomizedConf() const override;

 private:
  void VirtualInitFromOpConf() override;
  void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                             const ParallelContext* parallel_ctx) const override;
  LossKernelConf* GetMutLossKernelConf(KernelConf*) const override;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 44cf836ede..97202ec9bf 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -609,6 +609,21 @@ message AccuracyOpConf {
   required string accuracy = 4;
 }
 
+enum Norm {
+  L1 = 1;
+  L2 = 2;
+}
+
+message HingeLossOpConf {
+  required string prediction = 1;
+  required string label = 2;
+  required string loss = 3;
+  optional LossReductionType reduction = 4 [default = kSumOverN];
+  optional float weight_scalar = 5 [default = 1.0];
+  optional string weight = 6;
+  optional Norm norm = 7 [default = L1];
+}
+
 message OperatorConf {
   required string name = 1;
   optional string model_load_dir = 2;
@@ -643,6 +658,7 @@
     Conv2DOpConf conv_2d_conf = 126;
     Conv3DOpConf conv_3d_conf = 127;
     TransposeOpConf transpose_conf = 128;
+    HingeLossOpConf hinge_loss_conf = 129;
     DropoutOpConf dropout_conf = 140;
     AveragePooling1DOpConf average_pooling_1d_conf = 200;
     MaxPooling1DOpConf max_pooling_1d_conf = 201;
-- 
GitLab
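[Reviewer note] For completeness, a minimal sketch of how the conf added to op_conf.proto is populated from generated C++, mirroring the setters the unit test uses; the blob names are hypothetical placeholders:

#include "oneflow/core/operator/op_conf.pb.h"  // generated from op_conf.proto above

int main() {
  oneflow::OperatorConf op_conf;
  op_conf.set_name("hinge_loss");
  oneflow::HingeLossOpConf* conf = op_conf.mutable_hinge_loss_conf();
  conf->set_prediction("fc/out");     // hypothetical producer blob
  conf->set_label("decode/label");    // hypothetical producer blob
  conf->set_loss("hinge_loss/loss");
  conf->set_norm(oneflow::L2);        // optional; defaults to L1 per the conf above
  return 0;
}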