Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1eb30775
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1eb30775
编写于
4月 18, 2023
作者:
C
chenxujun
提交者:
GitHub
4月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add index_add, index_sample, put_along_axis, take_along_axis tests (#52572)
上级
afc2c598
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
280 addition
and
26 deletion
+280
-26
paddle/phi/kernels/funcs/gather_scatter_functor.h
paddle/phi/kernels/funcs/gather_scatter_functor.h
+8
-6
paddle/phi/kernels/gpu/index_add_grad_kernel.cu
paddle/phi/kernels/gpu/index_add_grad_kernel.cu
+1
-0
paddle/phi/kernels/gpu/index_add_kernel.cu
paddle/phi/kernels/gpu/index_add_kernel.cu
+1
-0
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
+2
-1
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
+2
-1
python/paddle/fluid/tests/unittests/test_index_add_op.py
python/paddle/fluid/tests/unittests/test_index_add_op.py
+65
-2
python/paddle/fluid/tests/unittests/test_index_sample_op.py
python/paddle/fluid/tests/unittests/test_index_sample_op.py
+45
-1
python/paddle/fluid/tests/unittests/test_put_along_axis_op.py
...on/paddle/fluid/tests/unittests/test_put_along_axis_op.py
+71
-6
python/paddle/fluid/tests/unittests/test_take_along_axis_op.py
...n/paddle/fluid/tests/unittests/test_take_along_axis_op.py
+61
-3
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+20
-4
未找到文件。
paddle/phi/kernels/funcs/gather_scatter_functor.h
浏览文件 @
1eb30775
...
...
@@ -21,12 +21,14 @@ limitations under the License. */
namespace
phi
{
namespace
funcs
{
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t( \
func, double) Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, \
phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t(func, unsigned char)
#define Instantiate_Template_Function_index_t(func, tensor_t) \
...
...
paddle/phi/kernels/gpu/index_add_grad_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -105,5 +105,6 @@ PD_REGISTER_KERNEL(index_add_grad,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/index_add_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -123,5 +123,6 @@ PD_REGISTER_KERNEL(index_add,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
double
,
int64_t
,
int
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/put_along_axis_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -82,4 +82,5 @@ PD_REGISTER_KERNEL(put_along_axis,
double
,
int64_t
,
int
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/take_along_axis_grad_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -68,4 +68,5 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
double
,
int64_t
,
int
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/take_along_axis_kernel.cu
浏览文件 @
1eb30775
...
...
@@ -54,4 +54,5 @@ PD_REGISTER_KERNEL(take_along_axis,
double
,
int64_t
,
int
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
python/paddle/fluid/tests/unittests/test_index_add_op.py
浏览文件 @
1eb30775
...
...
@@ -15,10 +15,10 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
from
paddle.fluid
import
Program
from
paddle.fluid
import
Program
,
core
def
compute_index_add_ref
(
...
...
@@ -99,6 +99,69 @@ class TestIndexAddOp(OpTest):
self
.
check_grad
([
'X'
,
'AddValue'
],
'Out'
)
class
TestIndexAddFP16Op
(
TestIndexAddOp
):
def
init_dtype_type
(
self
):
self
.
axis
=
0
self
.
x_type
=
np
.
float16
self
.
index_type
=
np
.
int64
self
.
x_shape
=
(
101
,
3
)
self
.
index_size
=
3
self
.
add_value_shape
=
(
3
,
3
)
self
.
dtype
=
np
.
float16
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA or not support bfloat16"
,
)
class
TestIndexAddBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
python_api
=
raw_index_add
self
.
op_type
=
"index_add"
self
.
init_dtype_type
()
index_np
=
np
.
random
.
randint
(
low
=
0
,
high
=
self
.
x_shape
[
self
.
axis
],
size
=
self
.
index_size
)
x_np
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
add_value_np
=
np
.
random
.
random
(
self
.
add_value_shape
).
astype
(
self
.
x_type
)
self
.
inputs
=
{
'X'
:
convert_float_to_uint16
(
x_np
),
'Index'
:
index_np
,
'AddValue'
:
convert_float_to_uint16
(
add_value_np
),
}
self
.
attrs
=
{
'axis'
:
self
.
axis
}
out
=
compute_index_add_ref
(
self
.
axis
,
self
.
x_shape
,
x_np
,
self
.
add_value_shape
,
add_value_np
,
self
.
index_size
,
index_np
,
)
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
out
)}
self
.
place
=
core
.
CUDAPlace
(
0
)
def
init_dtype_type
(
self
):
self
.
axis
=
0
self
.
x_type
=
np
.
float32
self
.
index_type
=
np
.
int64
self
.
x_shape
=
(
101
,
3
)
self
.
index_size
=
3
self
.
add_value_shape
=
(
3
,
3
)
self
.
dtype
=
np
.
uint16
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad_normal
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'X'
,
'AddValue'
],
'Out'
)
class
TestIndexAddAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
setType
()
...
...
python/paddle/fluid/tests/unittests/test_index_sample_op.py
浏览文件 @
1eb30775
...
...
@@ -15,10 +15,11 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
from
paddle
import
fluid
from
paddle.fluid
import
core
class
TestIndexSampleOp
(
OpTest
):
...
...
@@ -121,6 +122,49 @@ class TestCase6(TestIndexSampleOp):
self
.
index_type
=
"int64"
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA or not support bfloat16"
,
)
class
TestIndexSampleBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"index_sample"
self
.
python_api
=
paddle
.
index_sample
self
.
config
()
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
indexnp
=
np
.
random
.
randint
(
low
=
0
,
high
=
self
.
x_shape
[
1
],
size
=
self
.
index_shape
).
astype
(
self
.
index_type
)
self
.
inputs
=
{
'X'
:
xnp
,
'Index'
:
indexnp
}
index_array
=
[]
for
i
in
range
(
self
.
index_shape
[
0
]):
for
j
in
indexnp
[
i
]:
index_array
.
append
(
xnp
[
i
,
j
])
index_array
=
np
.
array
(
index_array
).
astype
(
self
.
x_type
)
out
=
np
.
reshape
(
index_array
,
self
.
index_shape
)
self
.
outputs
=
{
'Out'
:
out
}
self
.
inputs
[
'X'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'X'
])
self
.
outputs
[
'Out'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])
self
.
place
=
core
.
CUDAPlace
(
0
)
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'X'
],
'Out'
)
def
config
(
self
):
"""
For multi-dimension input
"""
self
.
x_shape
=
(
10
,
20
)
self
.
x_type
=
"float32"
self
.
dtype
=
np
.
uint16
self
.
index_shape
=
(
10
,
10
)
self
.
index_type
=
"int32"
class
TestIndexSampleShape
(
unittest
.
TestCase
):
def
test_shape
(
self
):
paddle
.
enable_static
()
...
...
python/paddle/fluid/tests/unittests/test_put_along_axis_op.py
浏览文件 @
1eb30775
...
...
@@ -16,7 +16,7 @@ import copy
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
from
paddle.framework
import
core
...
...
@@ -28,19 +28,18 @@ class TestPutAlongAxisOp(OpTest):
def
setUp
(
self
):
self
.
init_data
()
self
.
reduce_op
=
"assign"
self
.
dtype
=
'float64'
self
.
op_type
=
"put_along_axis"
self
.
python_api
=
paddle
.
tensor
.
put_along_axis
self
.
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
# numpy put_along_axis is an inplace ope
ar
ion.
# numpy put_along_axis is an inplace ope
rat
ion.
self
.
xnp_result
=
copy
.
deepcopy
(
self
.
xnp
)
np
.
put_along_axis
(
self
.
xnp_result
,
self
.
index
,
self
.
value
,
self
.
axis
)
self
.
target
=
self
.
xnp_result
broadcast_shape_list
=
list
(
self
.
x_shape
)
broadcast_shape_list
[
self
.
axis
]
=
1
self
.
br
ao
dcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
br
ao
dcast_shape
)
self
.
value_broadcast
=
np
.
broadcast_to
(
self
.
value
,
self
.
br
ao
dcast_shape
)
self
.
br
oa
dcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
br
oa
dcast_shape
)
self
.
value_broadcast
=
np
.
broadcast_to
(
self
.
value
,
self
.
br
oa
dcast_shape
)
self
.
inputs
=
{
'Input'
:
self
.
xnp
,
'Index'
:
self
.
index_broadcast
,
...
...
@@ -56,6 +55,7 @@ class TestPutAlongAxisOp(OpTest):
self
.
check_grad
([
"Input"
,
"Value"
],
"Result"
)
def
init_data
(
self
):
self
.
dtype
=
'float64'
self
.
x_type
=
"float64"
self
.
x_shape
=
(
10
,
10
,
10
)
self
.
value_type
=
"float64"
...
...
@@ -66,6 +66,71 @@ class TestPutAlongAxisOp(OpTest):
self
.
axis_type
=
"int64"
class
TestPutAlongAxisFP16Op
(
TestPutAlongAxisOp
):
def
init_data
(
self
):
self
.
dtype
=
np
.
float16
self
.
x_type
=
"float16"
self
.
x_shape
=
(
10
,
10
,
10
)
self
.
value_type
=
"float16"
self
.
value
=
np
.
array
([
99
]).
astype
(
self
.
value_type
)
self
.
index_type
=
"int32"
self
.
index
=
np
.
array
([[[
0
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
1
self
.
axis_type
=
"int64"
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not complied with CUDA and not support the bfloat16"
,
)
class
TestPutAlongAxisBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
init_data
()
self
.
reduce_op
=
"assign"
self
.
op_type
=
"put_along_axis"
self
.
python_api
=
paddle
.
tensor
.
put_along_axis
self
.
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
# numpy put_along_axis is an inplace operation.
self
.
xnp_result
=
copy
.
deepcopy
(
self
.
xnp
)
np
.
put_along_axis
(
self
.
xnp_result
,
self
.
index
,
self
.
value
,
self
.
axis
)
self
.
target
=
self
.
xnp_result
broadcast_shape_list
=
list
(
self
.
x_shape
)
broadcast_shape_list
[
self
.
axis
]
=
1
self
.
broadcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
broadcast_shape
)
self
.
value_broadcast
=
np
.
broadcast_to
(
self
.
value
,
self
.
broadcast_shape
)
self
.
inputs
=
{
'Input'
:
self
.
xnp
,
'Index'
:
self
.
index_broadcast
,
'Value'
:
self
.
value_broadcast
,
}
self
.
attrs
=
{
'Axis'
:
self
.
axis
,
'Reduce'
:
self
.
reduce_op
}
self
.
outputs
=
{
'Result'
:
self
.
target
}
self
.
inputs
[
'Input'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Input'
])
self
.
inputs
[
'Value'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Value'
])
self
.
outputs
[
'Result'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Result'
])
self
.
place
=
core
.
CUDAPlace
(
0
)
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
"Input"
,
"Value"
],
"Result"
)
def
init_data
(
self
):
self
.
dtype
=
np
.
uint16
self
.
x_type
=
"float32"
self
.
x_shape
=
(
10
,
10
,
10
)
self
.
value_type
=
"float32"
self
.
value
=
np
.
array
([
99
]).
astype
(
self
.
value_type
)
self
.
index_type
=
"int32"
self
.
index
=
np
.
array
([[[
0
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
1
self
.
axis_type
=
"int64"
class
TestPutAlongAxisAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
0
)
...
...
python/paddle/fluid/tests/unittests/test_take_along_axis_op.py
浏览文件 @
1eb30775
...
...
@@ -15,7 +15,7 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
from
paddle.framework
import
core
...
...
@@ -32,8 +32,8 @@ class TestTakeAlongAxisOp(OpTest):
self
.
target
=
np
.
take_along_axis
(
self
.
xnp
,
self
.
index
,
self
.
axis
)
broadcast_shape_list
=
list
(
self
.
x_shape
)
broadcast_shape_list
[
self
.
axis
]
=
1
self
.
br
ao
dcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
br
ao
dcast_shape
)
self
.
br
oa
dcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
br
oa
dcast_shape
)
self
.
inputs
=
{
'Input'
:
self
.
xnp
,
'Index'
:
self
.
index_broadcast
,
...
...
@@ -58,6 +58,64 @@ class TestTakeAlongAxisOp(OpTest):
self
.
axis_type
=
"int64"
class
TestTakeAlongAxisFP16Op
(
TestTakeAlongAxisOp
):
def
init_data
(
self
):
self
.
dtype
=
np
.
float16
self
.
x_type
=
"float16"
self
.
x_shape
=
(
5
,
5
,
5
)
self
.
index_type
=
"int32"
self
.
index
=
np
.
array
([[[
1
]],
[[
1
]],
[[
2
]],
[[
4
]],
[[
3
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
self
.
axis_type
=
"int64"
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not complied with CUDA and not support the bfloat16"
,
)
class
TestTakeAlongAxisBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
init_data
()
self
.
op_type
=
"take_along_axis"
self
.
python_api
=
paddle
.
tensor
.
take_along_axis
self
.
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
x_type
)
self
.
target
=
np
.
take_along_axis
(
self
.
xnp
,
self
.
index
,
self
.
axis
)
broadcast_shape_list
=
list
(
self
.
x_shape
)
broadcast_shape_list
[
self
.
axis
]
=
1
self
.
broadcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
broadcast_shape
)
self
.
inputs
=
{
'Input'
:
self
.
xnp
,
'Index'
:
self
.
index_broadcast
,
}
self
.
attrs
=
{
'Axis'
:
self
.
axis
}
self
.
outputs
=
{
'Result'
:
self
.
target
}
self
.
inputs
[
'Input'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Input'
])
self
.
outputs
[
'Result'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Result'
])
self
.
place
=
core
.
CUDAPlace
(
0
)
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'Input'
],
'Result'
)
def
init_data
(
self
):
self
.
dtype
=
np
.
uint16
self
.
x_type
=
"float32"
self
.
x_shape
=
(
5
,
5
,
5
)
self
.
index_type
=
"int32"
self
.
index
=
np
.
array
([[[
1
]],
[[
1
]],
[[
2
]],
[[
4
]],
[[
3
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
self
.
axis_type
=
"int64"
class
TestCase1
(
TestTakeAlongAxisOp
):
def
init_data
(
self
):
self
.
x_type
=
"float64"
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
1eb30775
...
...
@@ -4540,7 +4540,15 @@ def take_along_axis(arr, indices, axis):
check_variable_and_dtype
(
arr
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
,
'uint16'
,
],
'take_along_axis'
,
)
check_variable_and_dtype
(
...
...
@@ -4612,7 +4620,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
check_variable_and_dtype
(
arr
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
,
'uint16'
,
],
'put_along_axis'
,
)
check_variable_and_dtype
(
...
...
@@ -4694,7 +4710,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'paddle.tensor.manipulation.index_add'
,
)
check_variable_and_dtype
(
...
...
@@ -4706,7 +4722,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype
(
value
,
'add_value'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'paddle.tensor.manipulation.index_add'
,
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录