From af9a3301dab9ab291d3cdd278734ae129de8a0f0 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 21 Nov 2018 12:35:21 +0000 Subject: [PATCH] test=develop --- paddle/fluid/framework/selected_rows.h | 6 +- .../operators/hierarchical_sigmoid_op.cc | 5 +- .../fluid/operators/hierarchical_sigmoid_op.h | 2 +- .../fluid/tests/unittests/test_hsigmoid_op.py | 269 ++++++++++-------- 4 files changed, 152 insertions(+), 130 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 4d728ae54ae..9d87c3eac7f 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -121,7 +121,9 @@ class SelectedRows { int64_t AutoGrownIndex(int64_t key, bool auto_grown); void SyncIndex(); - + /* + * @brief Get complete Dims before + */ DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); dims[0] = height_; @@ -136,7 +138,7 @@ class SelectedRows { std::unordered_map id_to_index_; // should not be used when ids has duplicate member std::unique_ptr value_{nullptr}; - int64_t height_; + int64_t height_; // height indicates the underline tensor's height std::unique_ptr rwlock_{nullptr}; }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index b2f46164415..c350e6489dd 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -145,8 +145,9 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("PreOut"), "Input(Preout) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), - "Output(W@Grad should not be null.)"); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); + "Output(W@Grad should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad should not be null."); if (ctx->HasOutput(framework::GradVarName("Bias"))) { ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 3e2fbafa266..35a1de3e191 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -191,10 +191,10 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { framework::Vector real_rows = cal_rows(path); auto* w_grad = ctx.Output(framework::GradVarName("W")); - w_grad->set_rows(real_rows); // build ids -> rows index map w_grad->SyncIndex(); + w_grad->set_height(w->dims()[0]); auto* w_grad_value = w_grad->mutable_value(); framework::DDim temp_dim(w->dims()); set(temp_dim, 0, real_rows.size()); diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py index 50dfaee76fd..2f4225f912d 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py @@ -140,148 +140,167 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): return pre_output, out -# class TestHSigmoidOp(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.random.randint(0, num_classes, (batch_size, 1)) -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': False} -# self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} -# pre_output, out = hsigmoid(x, w, label, bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# self.check_output() - -# def test_check_grad(self): -# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) - -# class TestHSigmoidOpSparse(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.array([0, 1, 4, 5]) -# ptable = np.array( -# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), -# (0, 2, -1, -1, -# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( -# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': True} -# self.inputs = { -# 'X': x, -# 'W': w, -# 'PTable': ptable, -# 'PCode': pcode, -# 'Label': label, -# 'Bias': bias -# } -# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, -# bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# print("checking output in CostumTree") -# self.check_output() - - -class TestHSigmoidOpWithSparseGrad(): - def hs_net_conf(self): - emb = fluid.layers.data(name="x", shape=[3], dtype='int64') +class TestHSigmoidOp(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.random.randint(0, num_classes, (batch_size, 1)) + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} + pre_output, out = hsigmoid(x, w, label, bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + + +class TestHSigmoidOpSparse(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") + w = np.random.random((num_classes - 1, feature_size)).astype("float32") + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': True} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, + bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + print("checking output in CostumTree") + self.check_output() + + +class TestHSigmoidOpWithSparseGrad(unittest.TestCase): + def hs_net_conf(self, is_sparse): + input_word = fluid.layers.data(name="x", shape=[1], dtype='int64') ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64') - data_list = [emb, ptable, pcode, label] + + data_list = [input_word, ptable, pcode, label] + + emb = fluid.layers.embedding( + input=input_word, + is_sparse=False, + size=[3, 3], + param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal( + scale=1 / math.sqrt(3)))) + cost = fluid.layers.hsigmoid( input=emb, - label=predict_word, - non_leaf_num=4, + label=label, + non_leaf_num=3, ptable=ptable, pcode=pcode, is_costum=True, - is_sparse=True) + is_sparse=is_sparse) avg_cost = fluid.layers.reduce_mean(cost) return avg_cost, data_list - def test_training_test(self): - print("im here") - w = np.arange(12).reshape(4, 3) - x = np.ones((2, 3)) - ptable = np.array([(1, 2, -1), (1, 2, -1)]) - pcode = np.array([(1, 0, -1), (0, 0, -1)]) - label = np.array([(1, 4)]) - - loss, data_list = hs_net_conf() - optimizer = fluid.optimizer.SGD(learning_rate=1e-3) - optimizer.minimize(loss) - - main_program = fluid.default_main_program() - - place = fluid.CPUPlace() - feeder = fluid.DataFeeder(feed_list=data_list, place=place) - data_name_list = [var.name for var in data_list] - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - for pass_id in range(args.num_passes): + def training_test(self, is_sparse): + with fluid.program_guard(fluid.Program(), fluid.Program()): + start_up = fluid.default_startup_program() + start_up.random_seed = 1 # Fix random seed + x = np.arange(6).reshape(6) + ptable = np.array([(1, 2, -1), (1, 2, -1)]) + pcode = np.array([(1, 0, -1), (0, 0, -1)]) + label = np.array([1, 4]) + + loss, data_list = self.hs_net_conf(is_sparse) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + main_program = fluid.default_main_program() + # print("main program: {program}".format{program=str(main_program)}) + place = fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=data_list, place=place) + exe = fluid.Executor(place) + + exe.run(start_up) + result = list() for i in range(10): - data = [w, x[i % 2], ptable[i % 2], pcode[i % 2], label[i % 2]] + data = [([[x[i % 2]]], [list(ptable[i % 2])], + [list(pcode[i % 2])], [label[i % 2]])] + loss_val = exe.run(main_program, feed=feeder.feed(data), fetch_list=[loss]) - print("loss is: {loss}".format(loss=loss)) - - -# class TestHSigmoidOpWithCostumTree(OpTest): -# def setUp(self): -# self.op_type = "hierarchical_sigmoid" -# num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample -# feature_size = 8 -# batch_size = 4 -# x = np.random.random((batch_size, feature_size)).astype("float32") * 2 -# w = np.random.random( -# (num_classes - 1, feature_size)).astype("float32") * 2 -# label = np.array([0, 1, 4, 5]) -# ptable = np.array( -# [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), -# (0, 2, -1, -1, -# -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -# pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( -# 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store -# bias = np.random.random((1, num_classes - 1)).astype("float32") -# self.attrs = {'num_classes': num_classes, 'is_sparse': False} -# self.inputs = { -# 'X': x, -# 'W': w, -# 'PTable': ptable, -# 'PCode': pcode, -# 'Label': label, -# 'Bias': bias -# } -# pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, -# bias, num_classes) -# self.outputs = {'PreOut': pre_output, 'Out': out} - -# def test_check_output(self): -# print("checking output in CostumTree") -# self.check_output() - -# def test_check_grad(self): -# print("checking outputGrad in CostumTree") -# self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + result.append(loss_val) + return result + + def test_hs_grad_with_sparse(self): + dense_result = self.training_test(is_sparse=False) + sparse_result = self.training_test(is_sparse=True) + assert (dense_result == sparse_result) + + +class TestHSigmoidOpWithCostumTree(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid" + num_classes = 6 #using 1,2,3,4,5,6 to build a huffman tree and select 1,2,5,6 as sample + feature_size = 8 + batch_size = 4 + x = np.random.random((batch_size, feature_size)).astype("float32") * 2 + w = np.random.random( + (num_classes - 1, feature_size)).astype("float32") * 2 + label = np.array([0, 1, 4, 5]) + ptable = np.array( + [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), + (0, 2, -1, -1, + -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) + pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( + 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store + bias = np.random.random((1, num_classes - 1)).astype("float32") + self.attrs = {'num_classes': num_classes, 'is_sparse': False} + self.inputs = { + 'X': x, + 'W': w, + 'PTable': ptable, + 'PCode': pcode, + 'Label': label, + 'Bias': bias + } + pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, + bias, num_classes) + self.outputs = {'PreOut': pre_output, 'Out': out} + + def test_check_output(self): + print("checking output in CostumTree") + self.check_output() + + def test_check_grad(self): + print("checking outputGrad in CostumTree") + self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) + if __name__ == '__main__': unittest.main() -- GitLab