Commit 32e05b01 authored by JiabinYang

test=develop

Parent c8801e10
@@ -86,6 +86,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out->numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
+    pre_out_mat = -1 * pre_out_mat;
     bit_code->Sum(*pre_out, out, static_cast<T>(-1));
     // use softrelu to calculate cross entropy
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
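Net effect of this hunk: the clipped pre-activations are negated before both the path-bit sum and the softrelu term. A minimal NumPy sketch of the reworked forward pass (assumed shapes: pre_out is [batch, code_length], bits marks cal_bit(j) along each sample's path; the row-summing of the softrelu term into out happens in lines truncated from this view):

import numpy as np

def hsigmoid_forward(pre_out, bits):
    pre_out = np.clip(pre_out, -40.0, 40.0)   # ClipFunctor(-40, 40)
    pre_out = -1 * pre_out                    # the negation added in this commit
    # bit_code->Sum(*pre_out, out, -1): sum over the 1-bits, scaled by -1
    out = -1.0 * (bits * pre_out).sum(axis=1, keepdims=True)
    # softrelu as the cross-entropy term, then row-summed into out
    out += np.log(1.0 + np.exp(pre_out)).sum(axis=1, keepdims=True)
    return out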
@@ -146,6 +147,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
     auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
     Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
     // softrelu derivative
@@ -160,9 +162,16 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
       bias_grad->mutable_data<T>(ctx.GetPlace());
       zero(dev_ctx, bias_grad, static_cast<T>(0.0));
       bit_code->AddGrad(pre_out_grad, bias_grad);
+      auto bias_grad_mat = EigenMatrix<T>::From(*bias_grad);
+      bias_grad_mat = -1 * bias_grad_mat;
     }
     bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
     bit_code->MulGradError(pre_out_grad, *w, in_grad);
+    auto w_grad_mat = EigenMatrix<T>::From(*w_grad);
+    auto in_grad_mat = EigenMatrix<T>::From(*in_grad);
+    w_grad_mat = -1 * w_grad_mat;
+    in_grad_mat = -1 * in_grad_mat;
   }
 };
...
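The backward hunks apply the matching fix: AddGrad, MulGradWeight, and MulGradError still produce gradients under the old sign convention, so the kernel negates each result in place. A hypothetical NumPy mirror of the added Eigen statements:

import numpy as np

def negate_grads(bias_grad, w_grad, in_grad):
    # one sign flip per gradient, matching bias_grad_mat = -1 * bias_grad_mat,
    # w_grad_mat = -1 * w_grad_mat, and in_grad_mat = -1 * in_grad_mat above
    return -1 * bias_grad, -1 * w_grad, -1 * in_grad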
@@ -157,7 +157,7 @@ class CustomCode : public Code {
   int get_length() const {
     int length = 0;
-    for (int i = 0; i < ptable_->dims()[1]; i++) {
+    for (int i = 0; i < static_cast<int>(ptable_->dims()[1]); i++) {
       if (ptable_->data<R>()[index_ * static_cast<int>(ptable_->dims()[1]) +
                              i] != -1) {
         length++;
...
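The static_cast only fixes the signed/unsigned comparison in the loop condition; the logic is unchanged. What get_length computes, as a Python sketch (assumed layout: each ptable row holds the node ids on one sample's path, padded with -1):

def get_length(ptable_row):
    # path length = number of entries before the -1 padding
    return sum(1 for ele in ptable_row if ele != -1)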
@@ -138,11 +138,8 @@ class OpTest(unittest.TestCase):
         cls.dtype = "float32"
         cls.outputs = {}
-        # np.random.seed(123)
-        # random.seed(124)
-        np.random.seed(190)
-        random.seed(200)
+        np.random.seed(123)
+        random.seed(124)

     @classmethod
     def tearDownClass(cls):
...
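Restoring the original seeds makes every OpTest run regenerate identical inputs, so a failing check reproduces exactly. A quick illustration of the effect:

import numpy as np

np.random.seed(123)
a = np.random.random((4, 8)).astype("float32")
np.random.seed(123)
b = np.random.random((4, 8)).astype("float32")
assert np.array_equal(a, b)  # same seed, same generated test data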
@@ -17,6 +17,9 @@ from __future__ import print_function
 import unittest
 import numpy as np
 import math
+# import paddle.fluid as fluid
+# import paddle.fluid.core as core
+# from op_builder import OpBuilder
 from op_test import OpTest
 np.random.seed(100)
@@ -51,7 +54,7 @@ class CodeTableWithCustomTree(object):
     def get_length(self):
         length = 0
-        for ele in self.ptable_[self.index_]:
+        for ele in self.ptable_[self.index_]:  # find the first -1 to stop trace
             if ele >= 0:
                 length = length + 1
@@ -71,12 +74,10 @@ def hsigmoid(x, w, label, bias, num_classes):
     pre_sum = np.zeros((batch_size, 1))
     out = np.zeros((batch_size, 1)).astype("float32")
     for i in range(batch_size):
-        #print("\n leaf {leaf}: \n".format(leaf = label[i]))
         code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         for j in range(length):
             idx = code_table.cal_index(j)
-            #print("index {index} ".format(index = j))
             pre_output[i][j] += bias[0][idx]
     for i in range(batch_size):
         code_table = CodeTable(num_classes, label[i])
@@ -87,13 +88,12 @@ def hsigmoid(x, w, label, bias, num_classes):
     # clip[-40.0, 40.0]
     pre_output = np.clip(pre_output, -40.0, 40.0)
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
+    pre_output = -1 * pre_output
     for i in range(batch_size):
-        #print("\n leaf {leaf}: \n".format(leaf = label[i]))
         code_table = CodeTable(num_classes, label[i])
         length = code_table.get_length()
         sum = 0.0
         for j in range(length):
-            #print("bit {bit} ".format(bit = code_table.cal_bit(j)))
             if code_table.cal_bit(j):
                 sum += pre_output[i][j]
         out[i] = -1.0 * sum
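A tiny worked instance of the reworked reference loss for one sample (hypothetical numbers; the softrelu term that is later added to out is truncated from this view):

import numpy as np

pre = np.array([0.3, -1.2, 2.0])     # one clipped row of pre_output
bits = np.array([1, 0, 1])           # cal_bit(j) along this label's path
pre = -1 * pre                       # the negation introduced above
inner = -1.0 * pre[bits == 1].sum()  # -1.0 * (-0.3 + -2.0) = 2.3
loss = inner + np.log(1.0 + np.exp(pre)).sum()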
@@ -108,6 +108,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
     batch_size = x.shape[0]
     code_length = len(ptable[0])
     code_table = [0 for _ in range(code_length)]
+    # init pre_out with shape [N, code_length]
     pre_output = np.zeros((batch_size, code_length))
     pre_sum = np.zeros((batch_size, 1))
     out = np.zeros((batch_size, 1)).astype("float32")
@@ -125,6 +126,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
             pre_output[i][j] += np.dot(w[idx], x[i])
     # clip[-40.0, 40.0]
     pre_output = np.clip(pre_output, -40.0, 40.0)
+    pre_output = -1 * pre_output
     # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
     for i in range(batch_size):
         code_table = CodeTableWithCustomTree(ptable, pcode, i)
@@ -141,26 +143,27 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
     return pre_output, out

-# class TestHSigmoidOp(OpTest):
-#     def setUp(self):
-#         self.op_type = "hierarchical_sigmoid"
-#         num_classes = 6
-#         feature_size = 8
-#         batch_size = 7
-#         x = np.random.random((batch_size, feature_size)).astype("float32")
-#         w = np.random.random((num_classes - 1, feature_size)).astype("float32")
-#         label = np.random.randint(0, num_classes, (batch_size, 1))
-#         bias = np.random.random((1, num_classes - 1)).astype("float32")
-#         self.attrs = {'num_classes': num_classes}
-#         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
-#         pre_output, out = hsigmoid(x, w, label, bias, num_classes)
-#         self.outputs = {'PreOut': pre_output, 'Out': out}
-#     def test_check_output(self):
-#         self.check_output()
-#     def test_check_grad(self):
-#         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
+class TestHSigmoidOp(OpTest):
+    def setUp(self):
+        self.op_type = "hierarchical_sigmoid"
+        num_classes = 6
+        feature_size = 8
+        batch_size = 4
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
+        w = np.random.random(
+            (num_classes - 1, feature_size)).astype("float32") * 2
+        label = np.random.randint(0, num_classes, (batch_size, 1))
+        bias = np.random.random((1, num_classes - 1)).astype("float32")
+        self.attrs = {'num_classes': num_classes}
+        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
+        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
+        self.outputs = {'PreOut': pre_output, 'Out': out}
+    def test_check_output(self):
+        self.check_output()
+    def test_check_grad(self):
+        self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))

 class TestHSigmoidOpWithCostumTree(OpTest):
@@ -169,9 +172,9 @@ class TestHSigmoidOpWithCostumTree(OpTest):
         num_classes = 6  #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
         feature_size = 8
         batch_size = 4
-        x = np.random.random((batch_size, feature_size)).astype("float32") * 10
+        x = np.random.random((batch_size, feature_size)).astype("float32") * 2
         w = np.random.random(
-            (num_classes - 1, feature_size)).astype("float32") * 10
+            (num_classes - 1, feature_size)).astype("float32") * 2
         label = np.array([0, 1, 4, 5])
         ptable = np.array(
             [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
...
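For a concrete read of the custom-tree inputs: sample 0's ptable row (0, 2, -1, -1, -1) encodes a path of length 2 through nodes 0 and 2, so only the first two pre_output entries for that sample are ever filled. A sketch:

ptable_row = (0, 2, -1, -1, -1)           # from the test data above
path = [ele for ele in ptable_row if ele != -1]
assert path == [0, 2]                     # node ids visited by sample 0
assert len(path) == 2                     # CodeTableWithCustomTree.get_length()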