Commit 4ee069fd authored by guosheng

Fix the HierarchicalSigmoidGradOpKernel and refine the code. Now hsigmoid_op matches the V2 implementation and passes the gradient check.
Parent e7f7ba97
@@ -42,13 +42,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     int64_t code_length = math::FindLastSet(num_classes - 1);
     int64_t batch_size = in->dims()[0];
     framework::Tensor sum;
-    math::SetConstant<DeviceContext, T> zero;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto* pre_out_data = pre_out->mutable_data<T>(
         framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     // Not all class(leaf) nodes' path lengths equal code_length, thus init as
     // 0s can avoid out of path's loss.
+    math::SetConstant<DeviceContext, T> zero;
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
@@ -72,6 +72,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     // use softrelu to calculate cross entropy
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(dev_ctx, *pre_out, &sum);
+    // TODO(guosheng): Subtract the out of path's loss, since not all
+    // class(leaf) nodes' path lengths equal code_length. But it won't break the
+    // gradient check since both have the out of path's loss and will cancel out
+    // each other.
     out_mat.device(place) = sum_mat + out_mat;
   }
 };
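For orientation, the loss accumulated here decomposes per path node as softrelu(z) - bit * z: since -log(sigmoid(z)) = log(1 + exp(z)) - z and -log(1 - sigmoid(z)) = log(1 + exp(z)), summing softrelu over the path (the row_sum above) plus the bit-weighted term gives the cross entropy. A minimal NumPy sketch for a single sample; the helper and its names are illustrative, not kernel code:

import numpy as np

def path_loss(pre_out, path_bits):
    # pre_out:   (code_length,) values of w*x + b along the sample's path
    # path_bits: (code_length,) 0/1 left/right choices along the same path
    pre_out = np.clip(pre_out, -40.0, 40.0)  # same clip as the forward kernel
    # cross entropy per node: softrelu(z) - bit * z
    return np.sum(np.log1p(np.exp(pre_out)) - path_bits * pre_out)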
@@ -90,33 +94,38 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    framework::Tensor pre_out_grad;
+    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
+    in_grad->mutable_data<T>(ctx.GetPlace());
+    w_grad->mutable_data<T>(ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::SetConstant<DeviceContext, T> zero;
+    zero(dev_ctx, in_grad, static_cast<T>(0.0));
+    zero(dev_ctx, w_grad, static_cast<T>(0.0));
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
-    int64_t code_length = math::FindLastSet(num_classes - 1);
-    int64_t batch_size = in->dims()[0];
-    framework::Tensor pre_out_grad;
-    pre_out_grad.mutable_data<T>(
-        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
-    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
     auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
-    pre_out_grad_mat = out_grad_mat.broadcast(bcast);
+    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
+    // softrelu derivative
+    pre_out_grad_mat.device(place) =
+        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
+    bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
     pre_out_grad_mat.device(place) =
-        pre_out_grad_mat *
-        (static_cast<T>(1.0) -
-         static_cast<T>(1.0) / pre_out_mat.exp());  // softrelu derivative
-    bit_code.Sub(&pre_out_grad);
+        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
     // be consistent with the clipping in forward.
     if (bias_grad) {
       bias_grad->mutable_data<T>(ctx.GetPlace());
+      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
       bit_code.AddGrad(pre_out_grad, bias_grad);
     }
-    in_grad->mutable_data<T>(ctx.GetPlace());
-    w_grad->mutable_data<T>(ctx.GetPlace());
     bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
     bit_code.MulGradError(pre_out_grad, *w, in_grad);
   }
......
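One step in the backward pass deserves a note: PreOut saved from the forward pass already holds softrelu(z) = log(1 + exp(z)), so 1 - 1/exp(PreOut) = 1 - 1/(1 + exp(z)) = sigmoid(z). With bit_code.Sub then subtracting the path bits, the kernel forms d(loss)/dz = sigmoid(z) - bit before scaling by the upstream gradient. A hedged NumPy sketch of the same arithmetic; the function and argument names are illustrative:

import numpy as np

def path_loss_grad(saved_pre_out, path_bits, out_grad):
    # saved_pre_out: softrelu(z) kept from the forward pass
    sigmoid_z = 1.0 - 1.0 / np.exp(saved_pre_out)  # recovers sigmoid(z)
    # matches bit_code.Sub followed by the broadcast multiply with out_grad
    return (sigmoid_z - path_bits) * out_grad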
@@ -62,6 +62,8 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       if (code.calc_bit(j)) {
+        // calc_bit starts from right most bit, while data in tmat[i] is in the
+        // reverse order.
         sm += tmat.data<T>()[i * o_width + j];
       }
     }
......
@@ -66,23 +66,20 @@ inline constexpr size_t FindLastSet(size_t x) {
 struct SimpleCode {
   SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
   /**
-   * calc_index should make sure that all siblings have the same weight index.
-   * As for which weight index it maps to, it doesn't matter. To satisfy this,
-   * the id of root should be 1, and the left child of a node i is 2*i, the
-   * right child of a node i is 2*i+1.
+   * Here the id of root should be 1 rather than 0, thus the encoding of class
+   * c is `c + num_classes` and all siblings can get the same weight index
+   * using prefixes.
+   * Weight index is the prefixes of encoding, thus leave out the right most
+   * bit in calc_index.
+   * Binary classification path is the suffixes of encoding, thus leave out the
+   * left most bit in calc_bit.
    */
   inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
-  /**
-   * calc_bit uses the right most bits, while calc_index uses the left most
-   * bits. They are not the same, and that's why we say it doesn't matter which
-   * weight index calc_index maps to.
-   */
   inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
   inline int get_length() const { return FindLastSet(c_) - 1; }

 private:
-  size_t c_;  // Here the id of root is 1 rather than 0, thus the id of class c
-  // is `c + num_classes`.
+  size_t c_;
 };

 struct SimpleCodeTable {
......
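To make the prefix/suffix convention concrete, here is a Python mirror of SimpleCode (illustration only, not part of the commit). With num_classes = 6 and class 3, the encoding is c_ = 9 = 0b1001, the path length is 3, the weight indices (prefixes) are [3, 1, 0], and the classification bits (suffixes) are [1, 0, 0]:

def find_last_set(x):
    # 1-based index of the highest set bit, as in math::FindLastSet
    return x.bit_length()

class PySimpleCode:
    def __init__(self, code, num_classes):
        # the root has id 1, so class `code` is encoded as code + num_classes
        self.c = code + num_classes

    def calc_index(self, bit):
        return (self.c >> (bit + 1)) - 1  # prefix of the encoding -> weight index

    def calc_bit(self, bit):
        return bool(self.c & (1 << bit))  # suffix bit -> left/right decision

    def get_length(self):
        return find_last_set(self.c) - 1

code = PySimpleCode(3, 6)  # c = 9 = 0b1001
assert code.get_length() == 3
assert [code.calc_index(j) for j in range(3)] == [3, 1, 0]
assert [code.calc_bit(j) for j in range(3)] == [True, False, False]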
@@ -37,7 +37,6 @@ class CodeTable(object):
 def hsigmoid(x, w, label, bias, num_classes):
-    global pre_output
     batch_size = x.shape[0]
     code_length = find_latest_set(num_classes - 1)
     code_table = [0 for _ in range(code_length)]
@@ -50,12 +49,12 @@ def hsigmoid(x, w, label, bias, num_classes):
         for j in range(length):
             idx = code_table.cal_index(j)
             pre_output[i][j] += bias[0][idx]
-    for j in range(batch_size):
-        code_table = CodeTable(num_classes, label[j])
+    for i in range(batch_size):
+        code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
-        for k in range(length):
-            idx = code_table.cal_index(k)
-            pre_output[j][k] = np.dot(w[idx], x[j])
+        for j in range(length):
+            idx = code_table.cal_index(j)
+            pre_output[i][j] += np.dot(w[idx], x[i])
     # clip[-40.0, 40.0]
     pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
@@ -71,22 +70,22 @@ def hsigmoid(x, w, label, bias, num_classes):
     pre_output = np.log(1 + np.exp(pre_output))
     pre_sum = pre_output.sum(1).reshape((batch_size, 1))
     out += pre_sum
-    return out
+    return pre_output, out


 class TestHSigmoidOp(OpTest):
     def setUp(self):
         self.op_type = "hierarchical_sigmoid"
         num_classes = 6
-        feature_size = 5
+        feature_size = 8
         batch_size = 4
         x = np.random.random((batch_size, feature_size)).astype("float32")
         w = np.random.random((num_classes - 1, feature_size)).astype("float32")
-        label = np.random.randint(0, num_classes, batch_size)
+        label = np.random.randint(0, num_classes, (batch_size, 1))
         bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.attrs = {'num_classes': num_classes}
         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
-        out = hsigmoid(x, w, label, bias, num_classes)
+        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
         self.outputs = {'PreOut': pre_output, 'Out': out}

     def test_check_output(self):
......
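The elided rest of the test presumably contains the gradient check the commit message refers to. A hypothetical sketch of such a method using OpTest's check_grad; the exact argument list here is an assumption, not the file's actual code:

    # Hypothetical; the real test's arguments may differ.
    def test_check_grad(self):
        self.check_grad(['X', 'W', 'Bias'], ['Out'], no_grad_set=set(['Label']))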