PaddlePaddle / PaddleHub

Commit 96096612 (unverified): Colorize (#897)

Authored by haoyuying on Sep 23, 2020; committed via GitHub on Sep 23, 2020.
Parent commit: 5889f7cf

Showing 9 changed files with 229 additions and 62 deletions (+229 -62).
Changed files:

  demo/colorization/house.png                                                 +0   -0
  demo/colorization/predict.py                                                +2   -3
  demo/colorization/sea.jpg                                                   +0   -0
  demo/colorization/train.py                                                  +13  -7
  hub_module/modules/image/colorization/user_guided_colorization/module.py    +14  -6
  paddlehub/datasets/colorizedataset.py                                       +4   -5
  paddlehub/module/cv_module.py                                               +86  -0
  paddlehub/process/functional.py                                             +21  -2
  paddlehub/process/transforms.py                                             +89  -39
demo/colorization/house.png (new file, mode 0 → 100644, 228.9 KB)
demo/colorization/predict.py

@@ -2,9 +2,8 @@ import paddle
 import paddlehub as hub
 import paddle.nn as nn

 if __name__ == '__main__':
     paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
+    model = hub.Module(name='user_guided_colorization')
     model.eval()
-    result = model.predict(images='sea.jpg')
\ No newline at end of file
+    result = model.predict(images='house.png')
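For reference, a minimal sketch of running the updated prediction entry point, assuming the user_guided_colorization module has already been installed (for example with `hub install user_guided_colorization`); the image path is illustrative:

    import paddle
    import paddlehub as hub

    if __name__ == '__main__':
        paddle.disable_static()
        # Load the installed module by name instead of from a local directory.
        model = hub.Module(name='user_guided_colorization')
        model.eval()
        # 'house.png' is the demo image added by this commit; any image path works here.
        result = model.predict(images='house.png')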
demo/colorization/sea.jpg (deleted, mode 100644 → 0, 10.8 KB)
demo/colorization/train.py

@@ -6,15 +6,21 @@ from paddlehub.finetune.trainer import Trainer
 from paddlehub.datasets.colorizedataset import Colorizedataset
 from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess

 if __name__ == '__main__':
     is_train = True
     paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
-    transform = Compose(
-        [Resize((256, 256), interp="RANDOM"),
-         RandomPaddingCrop(crop_size=176),
-         ConvertColorSpace(mode='RGB2LAB'),
-         ColorizePreprocess(ab_thresh=0, p=1)],
-        stay_rgb=True)
-    color_set = Colorizedataset(transform=transform, mode=is_train)
+    model = hub.Module(name='user_guided_colorization')
+    transform = Compose(
+        [
+            Resize((256, 256), interp='NEAREST'),
+            RandomPaddingCrop(crop_size=176),
+            ConvertColorSpace(mode='RGB2LAB'),
+            ColorizePreprocess(ab_thresh=0, is_train=is_train),
+        ],
+        stay_rgb=True,
+        is_permute=False)
+    color_set = Colorizedataset(transform=transform, mode='train')
     if is_train:
         model.train()
     optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
     trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
-    trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
+    trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
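As a rough reading of the change: the demo now loads the installed module by name, switches the first Resize from "RANDOM" to 'NEAREST' interpolation, passes is_train to ColorizePreprocess instead of a fixed p=1, adds is_permute=False alongside stay_rgb=True, builds the dataset with mode='train', and trains for 101 epochs with batch size 5 instead of 3 epochs with batch size 1. Below is a minimal sketch of applying the same preprocessing pipeline to one image outside the Trainer, assuming Compose objects accept an image path as in the module's transforms() method further down; the file name is illustrative:

    from paddlehub.process.transforms import (Compose, Resize, RandomPaddingCrop,
                                              ConvertColorSpace, ColorizePreprocess)

    transform = Compose(
        [
            Resize((256, 256), interp='NEAREST'),
            RandomPaddingCrop(crop_size=176),
            ConvertColorSpace(mode='RGB2LAB'),
            ColorizePreprocess(ab_thresh=0, is_train=True),
        ],
        stay_rgb=True,     # skip the RGB -> BGR conversion at the end of Compose
        is_permute=False)  # skip the final permute; the preprocessor shapes its own output

    sample = transform('house.png')  # preprocessed data dict used as the colorization input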
hub_module/modules/image/colorization/user_guided_colorization/module.py

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os

 import paddle
 import numpy
 import paddle.nn as nn
 from paddle.nn import Conv2d, ConvTranspose2d
 from paddlehub.module.module import moduleinfo
 from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
 from paddlehub.module.cv_module import ImageColorizeModule

@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
         if load_checkpoint is not None:
             model_dict = paddle.load(load_checkpoint)[0]
             self.set_dict(model_dict)
-            print("load pretrained model success")
+            print("load custom checkpoint success")
         else:
             checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
             model_dict = paddle.load(checkpoint)[0]
             self.set_dict(model_dict)
             print("load pretrained checkpoint success")

     def transforms(self, images: str, is_train: bool = True) -> callable:
         if is_train:
             transform = Compose(
-                [Resize((256, 256), interp="RANDOM"),
+                [Resize((256, 256), interp='NEAREST'),
                  RandomPaddingCrop(crop_size=176),
                  ConvertColorSpace(mode='RGB2LAB'),
                  ColorizePreprocess(ab_thresh=0, is_train=is_train)],
-                stay_rgb=True)
+                stay_rgb=True,
+                is_permute=False)
         else:
             transform = Compose(
-                [Resize((256, 256), interp="RANDOM"),
+                [Resize((256, 256), interp='NEAREST'),
                  ConvertColorSpace(mode='RGB2LAB'),
                  ColorizePreprocess(ab_thresh=0, is_train=is_train)],
-                stay_rgb=True)
+                stay_rgb=True,
+                is_permute=False)
         return transform(images)

     def forward(self,
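For reference, a small usage sketch of the transforms() method shown above; note that despite the `-> callable` annotation it returns the preprocessed data, because it applies the composed pipeline to `images` before returning. The image path is illustrative:

    import paddle
    import paddlehub as hub

    paddle.disable_static()
    model = hub.Module(name='user_guided_colorization')
    # Evaluation-time preprocessing: no random crop, deterministic resize.
    processed = model.transforms('house.png', is_train=False)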
paddlehub/datasets/colorizedataset.py

@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
 from paddlehub.env import DATA_HOME
 from typing import Callable


 class Colorizedataset(paddle.io.Dataset):
     """
     Dataset for colorization.

@@ -39,8 +40,6 @@ class Colorizedataset(paddle.io.Dataset):
             self.file = 'train'
         elif self.mode == 'test':
             self.file = 'test'
         else:
             self.file = 'validation'
         self.file = os.path.join(DATA_HOME, 'canvas', self.file)
         self.data = get_img_file(self.file)
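The dataset maps mode strings to sub-directories, so the expected on-disk layout (an inference from the os.path.join call above) is $DATA_HOME/canvas/{train, test, validation}:

    import os
    from paddlehub.env import DATA_HOME

    # Where Colorizedataset will look for images, per the join above.
    print(os.path.join(DATA_HOME, 'canvas', 'train'))       # mode == 'train'
    print(os.path.join(DATA_HOME, 'canvas', 'test'))        # mode == 'test'
    print(os.path.join(DATA_HOME, 'canvas', 'validation'))  # any other mode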
paddlehub/module/cv_module.py

@@ -18,6 +18,7 @@ import os
 from typing import List
 from collections import OrderedDict

 import cv2
 import numpy as np
 import paddle
 import paddle.nn as nn

@@ -27,6 +28,7 @@ from PIL import Image
 from paddlehub.module.module import serving, RunModule
 from paddlehub.utils.utils import base64_to_cv2
 from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
+from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix


 class ImageServing(object):

@@ -192,3 +194,87 @@ class ImageColorizeModule(RunModule, ImageServing):
             psnr_value = 20 * np.log10(255. / np.sqrt(mse))
             result.append(visual_ret)
         return result
+
+
+class StyleTransferModule(RunModule, ImageServing):
+    def training_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for training, which should be called as forward computation.
+
+        Args:
+            batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+
+        Returns:
+            results(dict) : The model outputs, such as loss and metrics.
+        '''
+        return self.validation_step(batch, batch_idx)
+
+    def validation_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for validation, which should be called as forward computation.
+
+        Args:
+            batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+
+        Returns:
+            results(dict) : The model outputs, such as metrics.
+        '''
+        mse_loss = nn.MSELoss()
+        N, C, H, W = batch[0].shape
+        batch[1] = batch[1][0].unsqueeze(0)
+        self.setTarget(batch[1])
+
+        y = self(batch[0])
+        xc = paddle.to_tensor(batch[0].numpy().copy())
+        y = subtract_imagenet_mean_batch(y)
+        xc = subtract_imagenet_mean_batch(xc)
+        features_y = self.getFeature(y)
+        features_xc = self.getFeature(xc)
+        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
+        content_loss = mse_loss(features_y[1], f_xc_c)
+
+        batch[1] = subtract_imagenet_mean_batch(batch[1])
+        features_style = self.getFeature(batch[1])
+        gram_style = [gram_matrix(y) for y in features_style]
+        style_loss = 0.
+        for m in range(len(features_y)):
+            gram_y = gram_matrix(features_y[m])
+            gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
+            style_loss += mse_loss(gram_y, gram_s[:N, :, :])
+
+        loss = content_loss + style_loss
+        return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}
+
+    def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
+        '''
+        Colorize images
+
+        Args:
+            origin_path(str): Content image path.
+            style_path(str): Style image path.
+            visualization(bool): Whether to save colorized images.
+            save_path(str) : Path to save colorized images.
+
+        Returns:
+            output(np.ndarray) : The style transformed images with bgr mode.
+        '''
+        content = paddle.to_tensor(self.transform(origin_path))
+        style = paddle.to_tensor(self.transform(style_path))
+        content = content.unsqueeze(0)
+        style = style.unsqueeze(0)
+
+        self.setTarget(style)
+        output = self(content)
+        output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()
+
+        if visualization:
+            output = output.astype(np.uint8)
+            style_name = "style_" + str(time.time()) + ".png"
+            if not os.path.exists(save_path):
+                os.mkdir(save_path)
+            path = os.path.join(save_path, style_name)
+            cv2.imwrite(path, output)
+        return output
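In validation_step above, the total loss is the feature-space MSE between the stylized output and the content image (content loss) plus the MSE between the Gram matrices of the output's feature maps and those of the style image (style loss). Below is a minimal usage sketch of the new predict() API; the module name and image paths are assumptions, since any Layer that mixes in StyleTransferModule gains this interface:

    import paddlehub as hub

    # 'my_style_transfer' is a placeholder name for a module built on StyleTransferModule.
    model = hub.Module(name='my_style_transfer')
    output = model.predict(origin_path='content.jpg',
                           style_path='style.jpg',
                           visualization=True,
                           save_path='style_result')  # BGR np.ndarray, also written to disk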
paddlehub/process/functional.py

@@ -15,6 +15,7 @@
 import os

 import cv2
+import paddle
 import numpy as np
 from PIL import Image, ImageEnhance

@@ -114,7 +115,25 @@ def get_img_file(dir_name: str) -> list:
             if not is_image_file(filename):
                 continue
             img_path = os.path.join(parent, filename)
             print(img_path)
             images.append(img_path)
     images.sort()
     return images


+def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
+    """Subtract ImageNet mean pixel-wise from a BGR image."""
+    mean = np.zeros(shape=batch.shape, dtype='float32')
+    mean[:, 0, :, :] = 103.939
+    mean[:, 1, :, :] = 116.779
+    mean[:, 2, :, :] = 123.680
+    mean = paddle.to_tensor(mean)
+    return batch - mean
+
+
+def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
+    """Get gram matrix"""
+    b, ch, h, w = data.shape
+    features = data.reshape((b, ch, w * h))
+    features_t = features.transpose((0, 2, 1))
+    gram = features.bmm(features_t) / (ch * h * w)
+    return gram
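As a quick sanity check of gram_matrix: for an input of shape (b, ch, h, w) it flattens each channel, multiplies the (b, ch, h*w) feature matrix by its transpose, and normalizes by ch*h*w, giving a (b, ch, ch) output. A small sketch, with illustrative tensor values:

    import paddle

    data = paddle.ones([2, 3, 4, 5], dtype='float32')   # (b, ch, h, w)
    b, ch, h, w = data.shape
    features = data.reshape((b, ch, w * h))
    gram = features.bmm(features.transpose((0, 2, 1))) / (ch * h * w)
    print(gram.shape)            # [2, 3, 3]
    print(float(gram[0, 0, 0]))  # (h*w) / (ch*h*w) = 1/3 for an all-ones input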
paddlehub/process/transforms.py
@@ -24,7 +24,7 @@ from paddlehub.process.functional import *

 class Compose:
-    def __init__(self, transforms, to_rgb=True, stay_rgb=False):
+    def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:

@@ -33,6 +33,7 @@ class Compose:
         self.transforms = transforms
         self.to_rgb = to_rgb
         self.stay_rgb = stay_rgb
+        self.is_permute = is_permute

     def __call__(self, im):
         if isinstance(im, str):

@@ -47,13 +48,14 @@ class Compose:
             im = op(im)
         if not self.stay_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+        if self.is_permute:
             im = permute(im)
         return im


 class RandomHorizontalFlip:
     def __init__(self, prob=0.5):
         self.prob = prob
@@ -239,8 +241,13 @@ class RandomPaddingCrop:
         pad_height = max(crop_height - img_height, 0)
         pad_width = max(crop_width - img_width, 0)
         if (pad_height > 0 or pad_width > 0):
-            im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value)
+            im = cv2.copyMakeBorder(im,
+                                    0,
+                                    pad_height,
+                                    0,
+                                    pad_width,
+                                    cv2.BORDER_CONSTANT,
+                                    value=self.im_padding_value)
         if crop_height > 0 and crop_width > 0:
             h_off = np.random.randint(img_height - crop_height + 1)
@@ -295,8 +302,7 @@ class RandomRotation:
             r[0, 2] += (nw / 2) - cx
             r[1, 2] += (nh / 2) - cy
             dsize = (nw, nh)
-            im = cv2.warpAffine(im,
+            im = cv2.warpAffine(im, r, dsize=dsize, flags=cv2.INTER_LINEAR,
@@ -429,7 +435,7 @@ class ConvertColorSpace:
         """
         mask = (rgb > 0.04045)
+        np.seterr(invalid='ignore')
         rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
+        rgb = np.nan_to_num(rgb)
         x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
         y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]

@@ -490,7 +496,7 @@ class ConvertColorSpace:
         rgb = np.maximum(rgb, 0)  # sometimes reaches a small negative number, which causes NaNs
         mask = (rgb > .0031308).astype(np.float32)
+        np.seterr(invalid='ignore')
         out = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
+        out = np.nan_to_num(out)
         return out
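Both hunks above guard the standard sRGB gamma formulas with np.seterr(invalid='ignore') and np.nan_to_num, so fractional powers of slightly negative values no longer propagate NaNs. A small numpy sketch of the forward (sRGB to linear) branch, assuming inputs roughly in [0, 1]:

    import numpy as np

    def srgb_to_linear(rgb: np.ndarray) -> np.ndarray:
        # Piecewise sRGB -> linear conversion, mirroring the rgb2xyz hunk above.
        mask = (rgb > 0.04045)
        np.seterr(invalid='ignore')               # silence warnings from negative bases
        out = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
        return np.nan_to_num(out)                 # replace any NaNs that slipped through

    # The -0.1 entry would become NaN without the nan_to_num guard.
    print(srgb_to_linear(np.array([0.0, 0.5, 1.0, -0.1])))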
@@ -511,7 +517,7 @@ class ConvertColorSpace:
         out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
         mask = (out > .2068966).astype(np.float32)
+        np.seterr(invalid='ignore')
         out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
+        out = np.nan_to_num(out)
         sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
         out = out * sc
@@ -566,7 +572,7 @@ class ColorizeHint:
         self.use_avg = use_avg

     def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
-        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ]
+        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
         self.data = data
         self.hint = hint
         self.mask = mask

@@ -577,7 +583,7 @@ class ColorizeHint:
         while cont_cond:
             if self.num_points is None:  # draw from geometric
                 # embed()
-                cont_cond = np.random.rand() < (1 - self.percent)
+                cont_cond = np.random.rand() > (1 - self.percent)
             else:  # add certain number of points
                 cont_cond = pp < self.num_points
             if not cont_cond:  # skip out of loop if condition not met
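For intuition about the geometric sampling above: each loop iteration places one hint patch and continues while the random draw passes, so the number of hint patches per image follows a geometric distribution whose mean is p / (1 - p) for continue-probability p. A tiny simulation sketch, with an illustrative probability value:

    import numpy as np

    def num_hint_points(continue_prob: float = 0.875) -> int:
        # Count how many hint patches would be placed before the loop exits.
        count = 0
        while np.random.rand() < continue_prob:
            count += 1
        return count

    print(np.mean([num_hint_points() for _ in range(10000)]))  # roughly 0.875 / 0.125 = 7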
@@ -593,9 +599,11 @@ class ColorizeHint:
             # add color point
             if self.use_avg:
                 # embed()
-                hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape(1, C, 1, 1)
+                hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(
+                    data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True),
+                    axis=1, keepdims=True).reshape(1, C, 1, 1)
             else:
                 hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
             mask[nn, :, h:h + P, w:w + P] = 1
@@ -641,8 +649,9 @@ class ColorizePreprocess:
         data(dict): The preprocessed data for colorization.
     """
-    def __init__(self, ab_thresh: float = 0., p: float = .125,
+    def __init__(self, ab_thresh: float = 0., p: float = 0.,
                  num_points: int = None, samp: str = 'normal', use_avg: bool = True,

@@ -668,11 +677,14 @@ class ColorizePreprocess:
         """
         data = {}
         A = 2 * 110 / 10 + 1
         data['A'] = data_lab[:, [0, ], :, :]
         data['B'] = data_lab[:, 1:, :, :]
         if self.ab_thresh > 0:  # mask out grayscale images
             thresh = 1. * self.ab_thresh / 110
-            mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)), axis=1)
+            mask = np.sum(np.abs(
+                np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
+                axis=1)
             mask = (mask >= thresh)
             data['A'] = data['A'][mask, :, :, :]
             data['B'] = data['B'][mask, :, :, :]
@@ -713,3 +725,41 @@ class ColorPostprocess:
         img = np.clip(img, 0, 1) * 255
         img = img.astype(self.type)
         return img
+
+
+class CenterCrop:
+    """
+    Crop the middle part of the image to the specified size.
+
+    Args:
+        crop_size(int): Crop size.
+
+    Return:
+        img(np.ndarray): Croped image.
+    """
+    def __init__(self, crop_size: int):
+        self.crop_size = crop_size
+
+    def __call__(self, img: np.ndarray):
+        img_width, img_height, chanel = img.shape
+        crop_top = int((img_height - self.crop_size) / 2.)
+        crop_left = int((img_width - self.crop_size) / 2.)
+        return img[crop_left:crop_left + self.crop_size, crop_top:crop_top + self.crop_size, :]
+
+
+class SetType:
+    """
+    Set image type.
+
+    Args:
+        type(type): Type of Image value.
+
+    Return:
+        img(np.ndarray): Transformed image.
+    """
+    def __init__(self, datatype: type = 'float32'):
+        self.type = datatype
+
+    def __call__(self, img: np.ndarray):
+        img = img.astype(self.type)
+        return img
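A small sketch of the two new transforms, applied to a square dummy image (on square inputs the height/width ordering inside CenterCrop.__call__ does not matter):

    import numpy as np
    from paddlehub.process.transforms import CenterCrop, SetType

    img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
    img = CenterCrop(crop_size=176)(img)     # -> (176, 176, 3) centre patch
    img = SetType(datatype='float32')(img)   # cast values to float32
    print(img.shape, img.dtype)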