diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f50a38842a21c795c979f859e88a9b16c3e54bd8..481cd52ee3eb021a39f0030c82fd596aeea7e500 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -121,6 +121,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs= paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) +paddle.fluid.layers.sampled_softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_custom_samples', 'custom_samples', 'custom_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/paddle/fluid/operators/sample_logits_op.cu b/paddle/fluid/operators/sample_logits_op.cu index fe95542fd8f5d620094bf99d0e197480855f17e1..eb55c14ff9c8bfc8cb7dfe379095ab19b06943d0 100644 --- a/paddle/fluid/operators/sample_logits_op.cu +++ b/paddle/fluid/operators/sample_logits_op.cu @@ -113,10 +113,9 @@ class SampleLogitsCUDAKernel : public framework::OpKernel { if (!FLAGS_debug_print) { return; } - VLOG(1) << "qxz print " << name; - VLOG(1) << name << "size = " << t.numel(); + VLOG(1) << name << " size = " << t.numel(); size_t size = t.numel(); - type* d = t.data(); + const type* d = t.data(); #ifdef PADDLE_WITH_CUDA std::vector vec; platform::DeviceContextPool::Instance().Get(t.place())->Wait(); @@ -126,7 +125,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel { d = vec.data(); } #endif - VLOG(1) << name << " data_ptr = " << static_cast(d); + VLOG(1) << name << " data_ptr = " << static_cast(d); std::string out; for (size_t i = 0; i < size; i++) { out += std::to_string(d[i]); diff --git a/python/paddle/fluid/tests/unittests/test_sample_logits.py b/python/paddle/fluid/tests/unittests/test_sample_logits.py index b36694f11fc167e1e0b7452e7ee7c3ee353ddb30..7419cc513bee1f4ce3695922cad6491a37a9b7b2 100644 --- a/python/paddle/fluid/tests/unittests/test_sample_logits.py +++ b/python/paddle/fluid/tests/unittests/test_sample_logits.py @@ -349,827 +349,16 @@ class TestSampleLogitsOpV3(OpTest): self.inputs = {'Logits': logits, 'Label': label} def set_data(self, num_classes, num_samples, seed, remove_accidental_hits): - self.fetched_samples = np.array([[ - 52, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 2, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 2, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 17, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 96, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 2, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 17, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 96, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 37, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ], [ - 2, - 3, - 12, - 74, - 28, - 1, - 79, - 2, - 42, - 8, - 13, - 0, - 18, - 88, - 49, - 14, - 46, - 39, - 57, - 26, - 75, - 9, - 50, - 16, - 66, - 6, - 23, - 5, - 11, - 17, - 54, - 35, - 20, - 53, - 10, - 47, - 80, - 38, - 7, - 4, - 31, - 15, - 19, - 58, - 22, - 34, - 41, - 73, - 62, - 95, - 25, - 70, - 37, - 30, - 65, - 27, - 51, - 43, - 32, - 99, - 21, - 56, - 29, - 40, - 69, - 55, - 98, - 77, - 67, - 33, - 89, - 63, - 81, - 59, - 48, - 91, - 68, - 72, - 61, - 52, - 86, - ]]) + label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2] + samples = [ + 3, 12, 74, 28, 1, 79, 2, 42, 8, 13, 0, 18, 88, 49, 14, 46, 39, 57, + 26, 75, 9, 50, 16, 66, 6, 23, 5, 11, 17, 54, 35, 20, 53, 10, 47, 80, + 38, 7, 4, 31, 15, 19, 58, 22, 34, 41, 73, 62, 95, 25, 70, 37, 30, + 65, 27, 51, 43, 32, 99, 21, 56, 29, 40, 69, 55, 98, 77, 67, 33, 89, + 63, 81, 59, 48, 91, 68, 72, 61, 52, 86 + ] + + self.fetched_samples = np.array([[x] + samples for x in label]) fectched_num_tries = 323 label = self.fetched_samples[:, 0:1]