Commit 4921c2cd authored by: X xuezhong

add api spec change

test=develop
Parent fb261793
@@ -121,6 +121,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
 paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
+paddle.fluid.layers.sampled_softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_custom_samples', 'custom_samples', 'custom_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0))
 paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
......
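The API-spec hunk above registers the new `paddle.fluid.layers.sampled_softmax_with_cross_entropy` layer. Going only by the ArgSpec shown there, a call might look like the sketch below; the `fluid.layers.data` setup, the tensor shapes, and `num_samples=25` are illustrative assumptions, not taken from this commit.

```python
import paddle.fluid as fluid

# Hypothetical inputs: 100-class logits and integer labels (shapes assumed).
logits = fluid.layers.data(name='logits', shape=[100], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# Remaining arguments keep the ArgSpec defaults: num_true=1,
# remove_accidental_hits=True, use_custom_samples=False,
# custom_samples=None, custom_probabilities=None, seed=0.
loss = fluid.layers.sampled_softmax_with_cross_entropy(
    logits=logits, label=label, num_samples=25)
```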
@@ -113,10 +113,9 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
   if (!FLAGS_debug_print) {
     return;
   }
-  VLOG(1) << "qxz print " << name;
-  VLOG(1) << name << "size = " << t.numel();
+  VLOG(1) << name << " size = " << t.numel();
   size_t size = t.numel();
-  type* d = t.data<type>();
+  const type* d = t.data<type>();
 #ifdef PADDLE_WITH_CUDA
   std::vector<type> vec;
   platform::DeviceContextPool::Instance().Get(t.place())->Wait();
@@ -126,7 +125,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
     d = vec.data();
   }
 #endif
-  VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
+  VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
   std::string out;
   for (size_t i = 0; i < size; i++) {
     out += std::to_string(d[i]);
......
@@ -349,827 +349,16 @@ class TestSampleLogitsOpV3(OpTest):
         self.inputs = {'Logits': logits, 'Label': label}

     def set_data(self, num_classes, num_samples, seed, remove_accidental_hits):
-        self.fetched_samples = np.array([[
-            52,
-            3,
-            12,
        [... the remaining ~815 removed lines spell out the same hard-coded literal, one value per line: 10 rows of 81 values, each row being one label followed by the shared 80-entry sample list ...]
-        ]])
+        label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2]
+        samples = [
+            3, 12, 74, 28, 1, 79, 2, 42, 8, 13, 0, 18, 88, 49, 14, 46, 39, 57,
+            26, 75, 9, 50, 16, 66, 6, 23, 5, 11, 17, 54, 35, 20, 53, 10, 47, 80,
+            38, 7, 4, 31, 15, 19, 58, 22, 34, 41, 73, 62, 95, 25, 70, 37, 30,
+            65, 27, 51, 43, 32, 99, 21, 56, 29, 40, 69, 55, 98, 77, 67, 33, 89,
+            63, 81, 59, 48, 91, 68, 72, 61, 52, 86
+        ]
+
+        self.fetched_samples = np.array([[x] + samples for x in label])
         fectched_num_tries = 323
         label = self.fetched_samples[:, 0:1]
......
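For reference, the test refactor above replaces the spelled-out `fetched_samples` literal with a construction from `label` and `samples`. The standalone sketch below reproduces that construction with the values from the diff; the shape and first-column checks are additions for illustration only.

```python
import numpy as np

label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2]
samples = [
    3, 12, 74, 28, 1, 79, 2, 42, 8, 13, 0, 18, 88, 49, 14, 46, 39, 57,
    26, 75, 9, 50, 16, 66, 6, 23, 5, 11, 17, 54, 35, 20, 53, 10, 47, 80,
    38, 7, 4, 31, 15, 19, 58, 22, 34, 41, 73, 62, 95, 25, 70, 37, 30,
    65, 27, 51, 43, 32, 99, 21, 56, 29, 40, 69, 55, 98, 77, 67, 33, 89,
    63, 81, 59, 48, 91, 68, 72, 61, 52, 86
]

# Each row is the true label followed by the shared sample list,
# which is exactly what the removed hard-coded literal spelled out.
fetched_samples = np.array([[x] + samples for x in label])

assert fetched_samples.shape == (len(label), 1 + len(samples))  # (10, 81)
assert (fetched_samples[:, 0] == np.array(label)).all()
```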