提交 af9a3301 编写于 作者: J JiabinYang

test=develop

上级 014e50c2
...@@ -121,7 +121,9 @@ class SelectedRows { ...@@ -121,7 +121,9 @@ class SelectedRows {
int64_t AutoGrownIndex(int64_t key, bool auto_grown); int64_t AutoGrownIndex(int64_t key, bool auto_grown);
void SyncIndex(); void SyncIndex();
/*
* @brief Get complete Dims before
*/
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims()); std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_; dims[0] = height_;
...@@ -136,7 +138,7 @@ class SelectedRows { ...@@ -136,7 +138,7 @@ class SelectedRows {
std::unordered_map<int64_t, int64_t> std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when ids has duplicate member id_to_index_; // should not be used when ids has duplicate member
std::unique_ptr<Tensor> value_{nullptr}; std::unique_ptr<Tensor> value_{nullptr};
int64_t height_; int64_t height_; // height indicates the underline tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr}; std::unique_ptr<RWLock> rwlock_{nullptr};
}; };
......
...@@ -145,8 +145,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -145,8 +145,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("PreOut"), PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null."); "Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)"); "Output(W@Grad should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad should not be null.");
if (ctx->HasOutput(framework::GradVarName("Bias"))) { if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias")); ctx->GetInputDim("Bias"));
......
...@@ -191,10 +191,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -191,10 +191,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
framework::Vector<int64_t> real_rows = cal_rows(path); framework::Vector<int64_t> real_rows = cal_rows(path);
auto* w_grad = auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows); w_grad->set_rows(real_rows);
// build ids -> rows index map // build ids -> rows index map
w_grad->SyncIndex(); w_grad->SyncIndex();
w_grad->set_height(w->dims()[0]);
auto* w_grad_value = w_grad->mutable_value(); auto* w_grad_value = w_grad->mutable_value();
framework::DDim temp_dim(w->dims()); framework::DDim temp_dim(w->dims());
set(temp_dim, 0, real_rows.size()); set(temp_dim, 0, real_rows.size());
......
...@@ -140,148 +140,167 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): ...@@ -140,148 +140,167 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
return pre_output, out return pre_output, out
# class TestHSigmoidOp(OpTest): class TestHSigmoidOp(OpTest):
# def setUp(self): def setUp(self):
# self.op_type = "hierarchical_sigmoid" self.op_type = "hierarchical_sigmoid"
# num_classes = 6 num_classes = 6
# feature_size = 8 feature_size = 8
# batch_size = 4 batch_size = 4
# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 x = np.random.random((batch_size, feature_size)).astype("float32") * 2
# w = np.random.random( w = np.random.random(
# (num_classes - 1, feature_size)).astype("float32") * 2 (num_classes - 1, feature_size)).astype("float32") * 2
# label = np.random.randint(0, num_classes, (batch_size, 1)) label = np.random.randint(0, num_classes, (batch_size, 1))
# bias = np.random.random((1, num_classes - 1)).astype("float32") bias = np.random.random((1, num_classes - 1)).astype("float32")
# self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.attrs = {'num_classes': num_classes, 'is_sparse': False}
# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
# pre_output, out = hsigmoid(x, w, label, bias, num_classes) pre_output, out = hsigmoid(x, w, label, bias, num_classes)
# self.outputs = {'PreOut': pre_output, 'Out': out} self.outputs = {'PreOut': pre_output, 'Out': out}
# def test_check_output(self): def test_check_output(self):
# self.check_output() self.check_output()
# def test_check_grad(self): def test_check_grad(self):
# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
# class TestHSigmoidOpSparse(OpTest):
# def setUp(self): class TestHSigmoidOpSparse(OpTest):
# self.op_type = "hierarchical_sigmoid" def setUp(self):
# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample self.op_type = "hierarchical_sigmoid"
# feature_size = 8 num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
# batch_size = 4 feature_size = 8
# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 batch_size = 4
# w = np.random.random( x = np.random.random((batch_size, feature_size)).astype("float32")
# (num_classes - 1, feature_size)).astype("float32") * 2 w = np.random.random((num_classes - 1, feature_size)).astype("float32")
# label = np.array([0, 1, 4, 5]) label = np.array([0, 1, 4, 5])
# ptable = np.array( ptable = np.array(
# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
# (0, 2, -1, -1, (0, 2, -1, -1,
# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store
# bias = np.random.random((1, num_classes - 1)).astype("float32") bias = np.random.random((1, num_classes - 1)).astype("float32")
# self.attrs = {'num_classes': num_classes, 'is_sparse': True} self.attrs = {'num_classes': num_classes, 'is_sparse': True}
# self.inputs = { self.inputs = {
# 'X': x, 'X': x,
# 'W': w, 'W': w,
# 'PTable': ptable, 'PTable': ptable,
# 'PCode': pcode, 'PCode': pcode,
# 'Label': label, 'Label': label,
# 'Bias': bias 'Bias': bias
# } }
# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
# bias, num_classes) bias, num_classes)
# self.outputs = {'PreOut': pre_output, 'Out': out} self.outputs = {'PreOut': pre_output, 'Out': out}
# def test_check_output(self): def test_check_output(self):
# print("checking output in CostumTree") print("checking output in CostumTree")
# self.check_output() self.check_output()
class TestHSigmoidOpWithSparseGrad(): class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
def hs_net_conf(self): def hs_net_conf(self, is_sparse):
emb = fluid.layers.data(name="x", shape=[3], dtype='int64') input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64')
pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
data_list = [emb, ptable, pcode, label]
data_list = [input_word, ptable, pcode, label]
emb = fluid.layers.embedding(
input=input_word,
is_sparse=False,
size=[3, 3],
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(3))))
cost = fluid.layers.hsigmoid( cost = fluid.layers.hsigmoid(
input=emb, input=emb,
label=predict_word, label=label,
non_leaf_num=4, non_leaf_num=3,
ptable=ptable, ptable=ptable,
pcode=pcode, pcode=pcode,
is_costum=True, is_costum=True,
is_sparse=True) is_sparse=is_sparse)
avg_cost = fluid.layers.reduce_mean(cost) avg_cost = fluid.layers.reduce_mean(cost)
return avg_cost, data_list return avg_cost, data_list
def test_training_test(self): def training_test(self, is_sparse):
print("im here") with fluid.program_guard(fluid.Program(), fluid.Program()):
w = np.arange(12).reshape(4, 3) start_up = fluid.default_startup_program()
x = np.ones((2, 3)) start_up.random_seed = 1 # Fix random seed
ptable = np.array([(1, 2, -1), (1, 2, -1)]) x = np.arange(6).reshape(6)
pcode = np.array([(1, 0, -1), (0, 0, -1)]) ptable = np.array([(1, 2, -1), (1, 2, -1)])
label = np.array([(1, 4)]) pcode = np.array([(1, 0, -1), (0, 0, -1)])
label = np.array([1, 4])
loss, data_list = hs_net_conf()
optimizer = fluid.optimizer.SGD(learning_rate=1e-3) loss, data_list = self.hs_net_conf(is_sparse)
optimizer.minimize(loss) optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
main_program = fluid.default_main_program()
main_program = fluid.default_main_program()
place = fluid.CPUPlace() # print("main program: {program}".format{program=str(main_program)})
feeder = fluid.DataFeeder(feed_list=data_list, place=place) place = fluid.CPUPlace()
data_name_list = [var.name for var in data_list] feeder = fluid.DataFeeder(feed_list=data_list, place=place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for pass_id in range(args.num_passes): exe.run(start_up)
result = list()
for i in range(10): for i in range(10):
data = [w, x[i % 2], ptable[i % 2], pcode[i % 2], label[i % 2]] data = [([[x[i % 2]]], [list(ptable[i % 2])],
[list(pcode[i % 2])], [label[i % 2]])]
loss_val = exe.run(main_program, loss_val = exe.run(main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
print("loss is: {loss}".format(loss=loss)) result.append(loss_val)
return result
# class TestHSigmoidOpWithCostumTree(OpTest): def test_hs_grad_with_sparse(self):
# def setUp(self): dense_result = self.training_test(is_sparse=False)
# self.op_type = "hierarchical_sigmoid" sparse_result = self.training_test(is_sparse=True)
# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample assert (dense_result == sparse_result)
# feature_size = 8
# batch_size = 4
# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 class TestHSigmoidOpWithCostumTree(OpTest):
# w = np.random.random( def setUp(self):
# (num_classes - 1, feature_size)).astype("float32") * 2 self.op_type = "hierarchical_sigmoid"
# label = np.array([0, 1, 4, 5]) num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample
# ptable = np.array( feature_size = 8
# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), batch_size = 4
# (0, 2, -1, -1, x = np.random.random((batch_size, feature_size)).astype("float32") * 2
# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) w = np.random.random(
# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( (num_classes - 1, feature_size)).astype("float32") * 2
# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store label = np.array([0, 1, 4, 5])
# bias = np.random.random((1, num_classes - 1)).astype("float32") ptable = np.array(
# self.attrs = {'num_classes': num_classes, 'is_sparse': False} [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
# self.inputs = { (0, 2, -1, -1,
# 'X': x, -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
# 'W': w, pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
# 'PTable': ptable, 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store
# 'PCode': pcode, bias = np.random.random((1, num_classes - 1)).astype("float32")
# 'Label': label, self.attrs = {'num_classes': num_classes, 'is_sparse': False}
# 'Bias': bias self.inputs = {
# } 'X': x,
# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, 'W': w,
# bias, num_classes) 'PTable': ptable,
# self.outputs = {'PreOut': pre_output, 'Out': out} 'PCode': pcode,
'Label': label,
# def test_check_output(self): 'Bias': bias
# print("checking output in CostumTree") }
# self.check_output() pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label,
bias, num_classes)
# def test_check_grad(self): self.outputs = {'PreOut': pre_output, 'Out': out}
# print("checking outputGrad in CostumTree")
# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) def test_check_output(self):
print("checking output in CostumTree")
self.check_output()
def test_check_grad(self):
print("checking outputGrad in CostumTree")
self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册