Commit f676d774 authored by qq_22305325, committed by Jinhui Yuan

Dev hinge loss (#1190)

* add hinge loss

* add hinge loss test

* hack hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss


Former-commit-id: e2da4ecf
Parent a5f1e505
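This commit adds a multiclass hinge loss operator and its CPU/GPU kernels. As a minimal sketch of what the code below computes (notation introduced here for clarity, not taken from the commit): for sample $i$ with prediction row $p_i$ and label $y_i$,

$$\ell_i = \sum_j \max\bigl(0,\ 1 + s_{ij}\,p_{ij}\bigr)^k,\qquad s_{ij}=\begin{cases}-1, & j=y_i\\ +1, & j\ne y_i\end{cases}$$

with $k=1$ when `norm = L1` and $k=2$ when `norm = L2`; the backward pass writes the corresponding subgradient into the prediction diff blob.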
#include "oneflow/core/kernel/hinge_loss_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename PredType, typename LabelType>
void HingeLossKernel<device_type, PredType, LabelType>::VirtualLossForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* prediction_blob = BnInOp2Blob("prediction");
const Blob* label_blob = BnInOp2Blob("label");
Blob* loss_blob = BnInOp2Blob("loss");
Blob* tmp_diff_blob = BnInOp2Blob("tmp_diff");
Blob* tmp_blob = BnInOp2Blob("tmp");
Blob* tmp_storage_blob = BnInOp2Blob("tmp_storage");
const int64_t piece_size = prediction_blob->shape().At(0);
const int64_t pre_dim = prediction_blob->shape().Count(1);
const OperatorConf& op_conf = this->op_conf();
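// tmp_diff starts as a copy of the predictions; Forward flips the sign at each
// sample's label column and clamps 1 + tmp_diff at zero in place.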
tmp_diff_blob->CopyDataContentFrom(ctx.device_ctx, prediction_blob);
// forward
HingeLossKernelUtil<device_type, PredType, LabelType>::Forward(
ctx.device_ctx, piece_size, pre_dim, prediction_blob->dptr<PredType>(),
label_blob->dptr<LabelType>(), op_conf, tmp_diff_blob->mut_dptr<PredType>(),
tmp_blob->mut_dptr<PredType>(), tmp_storage_blob->mut_dptr<PredType>(),
loss_blob->mut_dptr<PredType>());
// If prediction_diff_blob is not null, also run the backward pass.
Blob* prediction_diff_blob = BnInOp2Blob(GenDiffBn("prediction"));
if (prediction_diff_blob != nullptr) {
HingeLossKernelUtil<device_type, PredType, LabelType>::Backward(
ctx.device_ctx, piece_size, pre_dim, tmp_diff_blob->mut_dptr<PredType>(),
label_blob->dptr<LabelType>(), op_conf, prediction_diff_blob->mut_dptr<PredType>());
}
}
template<DeviceType device_type, typename PredType, typename LabelType>
const LossKernelConf& HingeLossKernel<device_type, PredType, LabelType>::GetLossKernelConf(
const KernelConf& kernel_conf) const {
return kernel_conf.hinge_loss_conf().loss_conf();
}
template<typename PredType, typename LabelType>
struct HingeLossKernelUtil<DeviceType::kCPU, PredType, LabelType> {
static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss) {
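// tmp_storage is scratch space for the GPU row-sum reduction and is unused in this CPU path.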
// Flip the sign of each prediction at its sample's label column.
for (int64_t i = 0; i < piece_size; ++i) {
tmp_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
}
// Hinge term: clamp 1 + tmp_diff at zero for every element.
for (int64_t i = 0; i < piece_size * pre_dim; ++i) {
tmp_diff[i] = (1 + tmp_diff[i]) > 0 ? (1 + tmp_diff[i]) : 0;
}
switch (op_conf.hinge_loss_conf().norm()) {
case L1:
KernelUtil<DeviceType::kCPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp_diff, loss);
break;
case L2:
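// L2: square each hinge term (elementwise Mul), then sum each row.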
KernelUtil<DeviceType::kCPU, PredType>::Mul(ctx, piece_size * pre_dim, tmp_diff, tmp_diff,
tmp);
KernelUtil<DeviceType::kCPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss);
/*for (int64_t i = 0; i < piece_size; ++i) {
KernelUtil<DeviceType::kCPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
tmp_diff + i * pre_dim, 1, loss + i);
}*/
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
}
static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* tmp_diff, const LabelType* label,
const OperatorConf& op_conf, PredType* pred_diff) {
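// pred_diff holds the subgradient: 1 where the margin is violated (tmp_diff > 0),
// with the sign flipped at the label column; the L2 case additionally scales by 2 * tmp_diff.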
for (int64_t i = 0; i < piece_size * pre_dim; ++i) { pred_diff[i] = (tmp_diff[i] > 0); }
for (int64_t i = 0; i < piece_size; ++i) {
pred_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
}
switch (op_conf.hinge_loss_conf().norm()) {
case L1: break;
case L2:
for (int64_t i = 0; i < piece_size * pre_dim; ++i) {
pred_diff[i] = 2 * tmp_diff[i] * pred_diff[i];
}
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
}
};
namespace {
Kernel* CreateHingeLossKernel(const KernelConf& kernel_conf) {
static const HashMap<std::string, std::function<Kernel*()>> creators = {
#define HINGE_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, label_type_pair) \
{GetHashKey(device_type, OF_PP_PAIR_SECOND(pred_type_pair), OF_PP_PAIR_SECOND(label_type_pair)), \
[]() { \
return new HingeLossKernel<device_type, OF_PP_PAIR_FIRST(pred_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>(); \
}},
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(HINGE_LOSS_KERNEL_ENTRY, DEVICE_TYPE_SEQ,
FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)};
return creators.at(GetHashKey(kernel_conf.op_attribute().op_conf().device_type(),
kernel_conf.hinge_loss_conf().loss_conf().prediction_type(),
kernel_conf.hinge_loss_conf().loss_conf().label_type()))();
}
} // namespace
REGISTER_KERNEL_CREATOR(OperatorConf::kHingeLossConf, CreateHingeLossKernel);
#define MAKE_ENTRY(data_type_pair, label_type_pair) \
template struct HingeLossKernelUtil<DeviceType::kCPU, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
} // namespace oneflow
#include "oneflow/core/kernel/hinge_loss_kernel.h"
namespace oneflow {
namespace {
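// The CUDA kernels below mirror the element-wise loops of the CPU implementation:
// one thread per sample for the label-column sign flips, one thread per element
// for the max(0, 1 + x) clamp and for the L2 scaling.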
template<typename PredType, typename LabelType>
__global__ void HingeLossForwardTransSignGpu(const int64_t piece_size, const int64_t pre_dim,
const LabelType* label, PredType* tmp_diff) {
CUDA_1D_KERNEL_LOOP(i, piece_size) {
tmp_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
}
}
template<typename PredType>
__global__ void HingeLossForwardMaxGpu(const int64_t piece_size, const int64_t pre_dim,
PredType* tmp_diff) {
CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) {
tmp_diff[i] = (1 + tmp_diff[i]) > 0 ? (1 + tmp_diff[i]) : 0;
}
}
template<typename PredType>
__global__ void HingeLossBackwardTransSignGpu(const int64_t piece_size, const int64_t pre_dim,
const PredType* tmp_diff, PredType* pred_diff) {
CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) { pred_diff[i] = (tmp_diff[i] > 0); }
}
template<typename PredType, typename LabelType>
__global__ void HingeLossBackwardL1Gpu(const int64_t piece_size, const int64_t pre_dim,
const LabelType* label, PredType* pred_diff) {
CUDA_1D_KERNEL_LOOP(i, piece_size) {
pred_diff[i * pre_dim + static_cast<int64_t>(label[i])] *= -1;
}
}
template<typename PredType>
__global__ void HingeLossBackwardL2Gpu(const int64_t piece_size, const int64_t pre_dim,
const PredType* tmp_diff, PredType* pred_diff) {
CUDA_1D_KERNEL_LOOP(i, piece_size * pre_dim) { pred_diff[i] = 2 * tmp_diff[i] * pred_diff[i]; }
}
} // namespace
template<typename PredType, typename LabelType>
struct HingeLossKernelUtil<DeviceType::kGPU, PredType, LabelType> {
static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss) {
HingeLossForwardTransSignGpu<<<BlocksNum4ThreadsNum(piece_size), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(piece_size, pre_dim, label, tmp_diff);
HingeLossForwardMaxGpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(piece_size, pre_dim, tmp_diff);
switch (op_conf.hinge_loss_conf().norm()) {
case L1:
KernelUtil<DeviceType::kGPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp_diff, loss,
tmp_storage,
sizeof(PredType) * piece_size * pre_dim);
break;
case L2:
KernelUtil<DeviceType::kGPU, PredType>::Mul(ctx, piece_size * pre_dim, tmp_diff, tmp_diff,
tmp);
KernelUtil<DeviceType::kGPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss,
tmp_storage,
sizeof(PredType) * piece_size * pre_dim);
/*for (int64_t i = 0; i < piece_size; ++i) {
KernelUtil<DeviceType::kGPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
tmp_diff + i * pre_dim, 1, loss + i);
}*/
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
}
static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* tmp_diff, const LabelType* label,
const OperatorConf& op_conf, PredType* pred_diff) {
HingeLossBackwardTransSignGpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim),
kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
piece_size, pre_dim, tmp_diff, pred_diff);
HingeLossBackwardL1Gpu<<<BlocksNum4ThreadsNum(piece_size), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(piece_size, pre_dim, label, pred_diff);
switch (op_conf.hinge_loss_conf().norm()) {
case L1: break;
case L2:
HingeLossBackwardL2Gpu<<<BlocksNum4ThreadsNum(piece_size * pre_dim),
kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
piece_size, pre_dim, tmp_diff, pred_diff);
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
}
};
#define MAKE_ENTRY(data_type_pair, label_type_pair) \
template struct HingeLossKernelUtil<DeviceType::kGPU, OF_PP_PAIR_FIRST(data_type_pair), \
OF_PP_PAIR_FIRST(label_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_ENTRY, FLOATING_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ)
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_
#include "oneflow/core/kernel/loss_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename PredType, typename LabelType>
class HingeLossKernel final : public LossKernel<device_type, PredType, LabelType> {
public:
OF_DISALLOW_COPY_AND_MOVE(HingeLossKernel);
HingeLossKernel() = default;
~HingeLossKernel() = default;
private:
void VirtualLossForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
const LossKernelConf& GetLossKernelConf(const KernelConf& kernel_conf) const override;
};
template<DeviceType device_type, typename PredType, typename LabelType>
struct HingeLossKernelUtil {
static void Forward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* pred, const LabelType* label, const OperatorConf& op_conf,
PredType* tmp_diff, PredType* tmp, PredType* tmp_storage, PredType* loss);
static void Backward(DeviceCtx* ctx, const int64_t piece_size, const int64_t pre_dim,
const PredType* tmp_diff, const LabelType* label,
const OperatorConf& op_conf, PredType* pred_diff);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_HINGE_LOSS_KERNEL_H_
#include "oneflow/core/kernel/opkernel_test_case.h"
#include "oneflow/core/common/switch_func.h"
namespace oneflow {
namespace test {
template<DeviceType device_type, typename PredType>
struct HingeLossTestUtil final {
#define HINGE_LOSS_TEST_UTIL_ENTRY(func_name, T) \
HingeLossTestUtil<device_type, PredType>::template func_name<T>
DEFINE_STATIC_SWITCH_FUNC(void, Test, HINGE_LOSS_TEST_UTIL_ENTRY,
MAKE_STRINGIZED_DATA_TYPE_CTRV_SEQ(INT_DATA_TYPE_SEQ));
template<typename LabelType>
static void Test(OpKernelTestCase* test_case, const std::string& job_type,
const std::string& fw_or_bw) {
test_case->set_is_train(job_type == "train");
test_case->set_is_forward(fw_or_bw == "forward");
HingeLossOpConf* hinge_loss_conf = test_case->mut_op_conf()->mutable_hinge_loss_conf();
hinge_loss_conf->set_norm(Norm::L2);
hinge_loss_conf->set_label("test/label");
hinge_loss_conf->set_prediction("test/prediction");
hinge_loss_conf->set_loss("test/loss");
BlobDesc* label_blob_desc =
new BlobDesc(Shape({2}), GetDataType<LabelType>::value, false, false, 1);
BlobDesc* pred_blob_desc =
new BlobDesc(Shape({2, 5}), GetDataType<PredType>::value, false, false, 1);
BlobDesc* loss_blob_desc =
new BlobDesc(Shape({2}), GetDataType<PredType>::value, false, false, 1);
test_case->InitBlob<LabelType>("label", label_blob_desc, {2, 2});
test_case->InitBlob<PredType>(
"prediction", pred_blob_desc,
{-1.73, -1.24, 0.89, -0.99, 0.05, -1.73, -1.24, 0.89, -0.99, 0.05});
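// Expected values (norm = L2, label = 2 for both rows): flipping the sign at the
// label column gives {-1.73, -1.24, -0.89, -0.99, 0.05}; max(0, 1 + x) yields
// {0, 0, 0.11, 0.01, 1.05}, so loss = 0.11^2 + 0.01^2 + 1.05^2 = 1.1147.
// The prediction diff is 2 * tmp_diff with the label column negated and
// zero where the margin is not violated.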
test_case->ForwardCheckBlob<PredType>("loss", loss_blob_desc, {1.1147, 1.1147});
test_case->BackwardCheckBlob<PredType>(
GenDiffBn("prediction"), pred_blob_desc,
{0.00, 0.00, -0.22, 0.02, 2.10, 0.00, 0.00, -0.22, 0.02, 2.10});
}
};
template<DeviceType device_type, typename PredType>
void HingeLossKernelTestCase(OpKernelTestCase* test_case, const std::string& label_type,
const std::string& job_type, const std::string& fw_or_bw) {
HingeLossTestUtil<device_type, PredType>::SwitchTest(SwitchCase(label_type), test_case, job_type,
fw_or_bw);
}
TEST_CPU_AND_GPU_OPKERNEL(HingeLossKernelTestCase, FLOATING_DATA_TYPE_SEQ,
OF_PP_SEQ_MAP(OF_PP_PAIR_FIRST, INT_DATA_TYPE_SEQ), (train)(predict),
(forward)(backward));
} // namespace test
} // namespace oneflow
@@ -137,6 +137,10 @@ message OpAttribute {
repeated string const_buf_bns = 14;
}
message HingeLossKernelConf {
required LossKernelConf loss_conf = 1;
}
message KernelConf {
required OpAttribute op_attribute = 1;
optional bool need_do_data_id = 2 [default = false];
@@ -152,6 +156,7 @@ message KernelConf {
SparseCrossEntropyLossKernelConf sparse_cross_entropy_loss_conf = 102;
DecodeRandomKernelConf decode_random_conf = 103;
DecodeOFRecordKernelConf decode_ofrecord_conf = 104;
HingeLossKernelConf hinge_loss_conf = 105;
SoftmaxKernelConf softmax_conf = 116;
TransposeKernelConf transpose_conf = 117;
ReduceSumKernelConf reduce_sum_conf = 120;
#include "oneflow/core/operator/hinge_loss_op.h"
namespace oneflow {
void HingeLossOp::VirtualInitFromOpConf() {
EnrollDataTmpBn("tmp");
EnrollDataTmpBn("tmp_diff");
EnrollDataTmpBn("tmp_storage"); // used by GPU
}
const PbMessage& HingeLossOp::GetCustomizedConf() const { return op_conf().hinge_loss_conf(); }
LossKernelConf* HingeLossOp::GetMutLossKernelConf(KernelConf* kernel_conf) const {
return kernel_conf->mutable_hinge_loss_conf()->mutable_loss_conf();
}
void HingeLossOp::VirtualInferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const BlobDesc* pred_blob_desc = GetBlobDesc4BnInOp("prediction");
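// Each temporary blob ("tmp", "tmp_diff", "tmp_storage") takes the prediction blob's shape and data type.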
#define OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(blobname) \
BlobDesc* blobname##_blob_desc = GetBlobDesc4BnInOp(#blobname); \
blobname##_blob_desc->mut_shape() = Shape(pred_blob_desc->shape()); \
blobname##_blob_desc->set_data_type(pred_blob_desc->data_type())
OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp);
OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp_diff);
OF_HINGE_LOSS_INFER_TMP_BLOB_DESC(tmp_storage);
#undef OF_HINGE_LOSS_INFER_TMP_BLOB_DESC
}
REGISTER_OP(OperatorConf::kHingeLossConf, HingeLossOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_
#define ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_
#include "oneflow/core/operator/loss_op.h"
namespace oneflow {
class HingeLossOp final : public LossOp {
public:
OF_DISALLOW_COPY_AND_MOVE(HingeLossOp);
HingeLossOp() = default;
~HingeLossOp() = default;
const PbMessage& GetCustomizedConf() const override;
private:
void VirtualInitFromOpConf() override;
void VirtualInferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
LossKernelConf* GetMutLossKernelConf(KernelConf*) const override;
};
} // namespace oneflow
#endif  // ONEFLOW_CORE_OPERATOR_HINGE_LOSS_OP_H_
@@ -609,6 +609,21 @@ message AccuracyOpConf {
required string accuracy = 4;
}
enum Norm {
L1 = 1;
L2 = 2;
}
message HingeLossOpConf {
required string prediction = 1;
required string label = 2;
required string loss = 3;
optional LossReductionType reduction = 4 [default = kSumOverN];
optional float weight_scalar = 5 [default = 1.0];
optional string weight = 6;
optional Norm norm = 7 [default = L1];
}
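As a rough usage sketch (field names come from HingeLossOpConf above; the blob names and the surrounding net definition are hypothetical), a hinge loss op could be configured in text-format OperatorConf as:
name: "hinge_loss"
hinge_loss_conf {
  prediction: "fc1/out"
  label: "decode/label"
  loss: "hinge_loss/loss"
  norm: L2
}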
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
@@ -643,6 +658,7 @@ message OperatorConf {
Conv2DOpConf conv_2d_conf = 126;
Conv3DOpConf conv_3d_conf = 127;
TransposeOpConf transpose_conf = 128;
HingeLossOpConf hinge_loss_conf = 129;
DropoutOpConf dropout_conf = 140;
AveragePooling1DOpConf average_pooling_1d_conf = 200;
MaxPooling1DOpConf max_pooling_1d_conf = 201;