提交 b709af61 编写于 作者: L Luo Tao

HuberTwoClassification only support one dimension

上级 e63ad0a6
......@@ -672,10 +672,10 @@ void HuberTwoClassification::forwardImp(Matrix& output,
Matrix& target) {
HuberCost::forwardImp(output, label, target);
size_t numSamples = target.getHeight();
size_t dim = output.getWidth();
CHECK(label.ids);
CHECK_EQ((*label.ids).getSize(), numSamples);
CHECK_EQ(output.getHeight(), numSamples);
CHECK_EQ(output.getWidth(), (size_t)1);
CHECK_EQ(target.getWidth(), (size_t)1);
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
......@@ -683,14 +683,11 @@ void HuberTwoClassification::forwardImp(Matrix& output,
std::vector<real> cost(numSamples, 0);
for (size_t i = 0; i < numSamples; ++i) {
int y = 2 * lbl[i] - 1;
for (size_t j = 0; j < dim; ++j) {
int index = i * dim + j;
real a = out[index] * y;
real a = out[i] * y;
if (a < -1)
cost[i] += -4 * a;
cost[i] = -4 * a;
else if (a < 1)
cost[i] += (1 - a) * (1 - a);
}
cost[i] = (1 - a) * (1 - a);
}
target.copyFrom(cost.data(), numSamples);
}
......@@ -699,22 +696,18 @@ void HuberTwoClassification::backwardImp(Matrix& output,
Argument& label,
Matrix& outputG) {
size_t numSamples = output.getHeight();
size_t dim = output.getWidth();
real* out = useGpu_ ? tmpCpuInput_[0].value->getData() : output.getData();
int* lbl = useGpu_ ? tmpCpuInput_[1].ids->getData() : (*label.ids).getData();
real* grad = useGpu_ ? tmpCpuInput_[0].grad->getData() : outputG.getData();
for (size_t i = 0; i < numSamples; ++i) {
int y = 2 * lbl[i] - 1;
for (size_t j = 0; j < dim; ++j) {
int index = i * dim + j;
real a = out[index] * y;
real a = out[i] * y;
if (a < -1)
grad[index] += -4 * y;
grad[i] += -4 * y;
else if (a < 1)
grad[index] += -2 * (1 - a) * y;
grad[i] += -2 * (1 - a) * y;
}
}
if (useGpu_) outputG.copyFrom(grad, numSamples * dim);
if (useGpu_) outputG.copyFrom(grad, numSamples);
}
/**
* This cost layer compute the sum of its input as loss.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册