提交 b8ff0972 编写于 作者: JiabinYang

test=develop

上级 32e05b01
...@@ -86,7 +86,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -86,7 +86,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
trans(ctx.template device_context<DeviceContext>(), pre_out_data, trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data, pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0))); ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
pre_out_mat = -1 * pre_out_mat;
bit_code->Sum(*pre_out, out, static_cast<T>(-1)); bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy // use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log(); pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
...@@ -162,16 +161,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -162,16 +161,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
bias_grad->mutable_data<T>(ctx.GetPlace()); bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0)); zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad); bit_code->AddGrad(pre_out_grad, bias_grad);
auto bias_grad_mat = EigenMatrix<T>::From(*bias_grad);
bias_grad_mat = -1 * bias_grad_mat;
} }
bit_code->MulGradWeight(pre_out_grad, w_grad, *in); bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
bit_code->MulGradError(pre_out_grad, *w, in_grad); bit_code->MulGradError(pre_out_grad, *w, in_grad);
auto w_grad_mat = EigenMatrix<T>::From(*w_grad);
auto in_grad_mat = EigenMatrix<T>::From(*in_grad);
w_grad_mat = -1 * w_grad_mat;
in_grad_mat = -1 * in_grad_mat;
} }
}; };
......
...@@ -88,7 +88,6 @@ def hsigmoid(x, w, label, bias, num_classes): ...@@ -88,7 +88,6 @@ def hsigmoid(x, w, label, bias, num_classes):
# clip[-40.0, 40.0] # clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0) pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j) # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
pre_output = -1 * pre_output
for i in range(batch_size): for i in range(batch_size):
code_table = CodeTable(num_classes, label[i]) code_table = CodeTable(num_classes, label[i])
length = code_table.get_length() length = code_table.get_length()
...@@ -126,7 +125,6 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): ...@@ -126,7 +125,6 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
pre_output[i][j] += np.dot(w[idx], x[i]) pre_output[i][j] += np.dot(w[idx], x[i])
# clip[-40.0, 40.0] # clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0) pre_output = np.clip(pre_output, -40.0, 40.0)
pre_output = -1 * pre_output
# out(i, 0) = \sum_j bit(i, j) * preout(i, j) # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size): for i in range(batch_size):
code_table = CodeTableWithCustomTree(ptable, pcode, i) code_table = CodeTableWithCustomTree(ptable, pcode, i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册