提交 87db37ed 编写于 作者: qq_22305325's avatar qq_22305325 提交者: Jinhui Yuan

Dev hinge loss (#1207)

* add hinge loss

* add hinge loss test

* hack hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss

* optimize hinge loss
上级 6ab4006f
......@@ -57,10 +57,6 @@ struct HingeLossKernelUtil<DeviceType::kCPU, PredType, LabelType> {
KernelUtil<DeviceType::kCPU, PredType>::Mul(ctx, piece_size * pre_dim, tmp_diff, tmp_diff,
tmp);
KernelUtil<DeviceType::kCPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss);
/*for (int64_t i = 0; i < piece_size; ++i) {
KernelUtil<DeviceType::kCPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
tmp_diff + i * pre_dim, 1, loss + i);
}*/
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
......
......@@ -63,10 +63,6 @@ struct HingeLossKernelUtil<DeviceType::kGPU, PredType, LabelType> {
KernelUtil<DeviceType::kGPU, PredType>::RowSum(ctx, piece_size, pre_dim, tmp, loss,
tmp_storage,
sizeof(PredType) * piece_size * pre_dim);
/*for (int64_t i = 0; i < piece_size; ++i) {
KernelUtil<DeviceType::kGPU, PredType>::Dot(ctx, pre_dim, tmp_diff + i * pre_dim, 1,
tmp_diff + i * pre_dim, 1, loss + i);
}*/
break;
default: LOG(FATAL) << "Invalid norm method in " << op_conf.name();
}
......
......@@ -629,6 +629,7 @@ message MultiplyOpConf {
required string in_1 = 2;
required string out = 4;
}
enum Norm {
L1 = 1;
L2 = 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册