未验证 提交 1716324c 编写于 作者: W wuhuanzhou 提交者: GitHub

Add paddle.lerp API to do a linear interpolation (#37253)

* save temp

* add unittest, test=develop

* fix ci error, test=develop

* fix grad accuracy error, test=develop

* fix unused error, test=develop

* fix compilation error on Windows, test=develop

* add unittest, test=develop

* modify by review comment and add lerp_

* fix inplace api, test=develop

* fix inplace api, test=develop

* fix coverage error, test=develop
上级 46212b80
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/lerp_op.h"
namespace paddle {
namespace operators {
class LerpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto w_dims = ctx->GetInputDim("Weight");
framework::DDim out_dims;
out_dims = GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = GetOutputDims(out_dims, w_dims);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
private:
framework::DDim GetOutputDims(const framework::DDim& s_dims,
const framework::DDim& l_dims) const {
if (s_dims.size() > l_dims.size()) {
return GetOutputDims(l_dims, s_dims);
}
std::vector<int64_t> shapes = framework::vectorize<int64_t>(l_dims);
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
int64_t s = s_dims[i];
int64_t l = l_dims[j];
if (s != l) {
if (l == 1) {
shapes[j] = s;
} else if (s != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor a %s:%d must match shape of tensor b "
"%s:%d.",
s_dims.to_str(), i, l_dims.to_str(), j));
}
}
}
return framework::make_ddim(shapes);
}
};
class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of lerp op.");
AddInput("Y", "(Tensor), The input tensor of lerp op.");
AddInput("Weight", "(Tensor, optional), The input tensor of lerp op.");
AddOutput("Out", "(Tensor), The output tensor of lerp op.");
AddComment(R"DOC(
Lerp Operator.
This operator is used to do a linear interpolation of input $X$ and $Y$ with $Weight$.
The equation is:
$$Out = X + Weight * (Y - X)$$
Both the input $X$ and $Y$ can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input $X$.
)DOC");
}
};
class LerpGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Y"))) {
ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
}
}
};
template <typename T>
class LerpOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("lerp_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
paddle::operators::LerpInplaceInferer);
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
REGISTER_OP_CPU_KERNEL(
lerp,
paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
lerp_grad,
paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
float>,
paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/lerp_op.h"
REGISTER_OP_CUDA_KERNEL(
lerp,
paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::LerpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
lerp_grad,
paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::LerpGradKernel<paddle::platform::CUDADeviceContext,
double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#endif
namespace paddle {
namespace operators {
static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims,
int rank) {
if (in_dims.size() == rank) {
return in_dims;
}
std::vector<int64_t> shapes(rank, 1);
for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
shapes[j] = in_dims[i];
}
return framework::make_ddim(shapes);
}
template <size_t D>
static void GetBroadcastDims(const framework::DDim& in_dims,
const framework::DDim& out_dims,
Eigen::DSizes<int, D>* bcast_dims) {
for (size_t i = 0; i < D; ++i) {
if (in_dims[i] == out_dims[i]) {
(*bcast_dims)[i] = 1;
} else {
(*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
}
}
}
template <typename DeviceContext, typename T, size_t D>
static void LerpFunction(const framework::ExecutionContext& ctx) {
auto x = ctx.Input<framework::Tensor>("X");
auto y = ctx.Input<framework::Tensor>("Y");
auto w = ctx.Input<framework::Tensor>("Weight");
auto out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto out_dims = out->dims();
auto x_dims = ExtendDims2Rank(x->dims(), D);
auto y_dims = ExtendDims2Rank(y->dims(), D);
auto w_dims = ExtendDims2Rank(w->dims(), D);
Eigen::DSizes<int, D> x_bcast_dims;
Eigen::DSizes<int, D> y_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
auto eigen_x = framework::EigenTensor<T, D>::From(*x, x_dims);
auto eigen_y = framework::EigenTensor<T, D>::From(*y, y_dims);
auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
auto eigen_out = framework::EigenTensor<T, D>::From(*out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(place) =
eigen_x.broadcast(x_bcast_dims) +
eigen_w.broadcast(w_bcast_dims) *
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}
template <typename DeviceContext, typename T, size_t D>
static void LerpGradFunction(const framework::ExecutionContext& ctx) {
auto w = ctx.Input<framework::Tensor>("Weight");
auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto dout_dims = dout->dims();
auto dx_dims = ExtendDims2Rank(dx->dims(), D);
auto dy_dims = ExtendDims2Rank(dy->dims(), D);
auto w_dims = ExtendDims2Rank(w->dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
auto eigen_dout = framework::EigenTensor<T, D>::From(*dout);
Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;
for (int i = 0; i < dout_dims.size(); ++i) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
dy_reshape_dims[2 * i] = dy_bcast_dims[i];
dy_reshape_dims[2 * i + 1] = dy_dims[i];
reduce_dims[i] = 2 * i;
}
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto eigen_dx = framework::EigenTensor<T, D>::From(*dx, dx_dims);
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dx.dimensions());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
auto eigen_dy = framework::EigenTensor<T, D>::From(*dy, dy_dims);
auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dy.dimensions());
}
}
template <typename DeviceContext, typename T>
class LerpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Output<framework::Tensor>("Out")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpFunction<DeviceContext, T, 1>(ctx);
break;
case 2:
LerpFunction<DeviceContext, T, 2>(ctx);
break;
case 3:
LerpFunction<DeviceContext, T, 3>(ctx);
break;
case 4:
LerpFunction<DeviceContext, T, 4>(ctx);
break;
case 5:
LerpFunction<DeviceContext, T, 5>(ctx);
break;
case 6:
LerpFunction<DeviceContext, T, 6>(ctx);
break;
}
}
};
template <typename DeviceContext, typename T>
class LerpGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
->dims()
.size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, 6, platform::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 1:
LerpGradFunction<DeviceContext, T, 1>(ctx);
break;
case 2:
LerpGradFunction<DeviceContext, T, 2>(ctx);
break;
case 3:
LerpGradFunction<DeviceContext, T, 3>(ctx);
break;
case 4:
LerpGradFunction<DeviceContext, T, 4>(ctx);
break;
case 5:
LerpGradFunction<DeviceContext, T, 5>(ctx);
break;
case 6:
LerpGradFunction<DeviceContext, T, 6>(ctx);
break;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -224,6 +224,7 @@ from .tensor.math import trunc # noqa: F401
from .tensor.math import digamma # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import lerp # noqa: F401
from .tensor.math import rad2deg # noqa: F401
from .tensor.math import deg2rad # noqa: F401
from .tensor.math import diff # noqa: F401
......@@ -469,6 +470,7 @@ __all__ = [ # noqa
'conj',
'neg',
'lgamma',
'lerp',
'square',
'divide',
'ceil',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
paddle.enable_static()
np.random.seed(0)
class TestLerp(OpTest):
def setUp(self):
self.op_type = "lerp"
self.init_dtype()
self.init_shape()
x = np.arange(1., 101.).astype(self.dtype).reshape(self.shape)
y = np.full(100, 10.).astype(self.dtype).reshape(self.shape)
w = np.asarray([0.5]).astype(self.dtype)
self.inputs = {'X': x, 'Y': y, 'Weight': w}
self.outputs = {'Out': x + w * (y - x)}
def init_dtype(self):
self.dtype = np.float64
def init_shape(self):
self.shape = [100]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out')
class TestLerpWithDim2(TestLerp):
def init_shape(self):
self.shape = [2, 50]
class TestLerpWithDim3(TestLerp):
def init_shape(self):
self.shape = [2, 2, 25]
class TestLerpWithDim4(TestLerp):
def init_shape(self):
self.shape = [2, 2, 5, 5]
class TestLerpWithDim5(TestLerp):
def init_shape(self):
self.shape = [2, 1, 2, 5, 5]
class TestLerpWithDim6(TestLerp):
def init_shape(self):
self.shape = [2, 1, 2, 5, 1, 5]
class TestLerpAPI(unittest.TestCase):
def init_dtype(self):
self.dtype = 'float32'
def setUp(self):
self.init_dtype()
self.x = np.arange(1., 5.).astype(self.dtype)
self.y = np.full(4, 10.).astype(self.dtype)
self.w = np.asarray([0.75]).astype(self.dtype)
self.res_ref = self.x + self.w * (self.y - self.x)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_static_api(self):
paddle.enable_static()
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', [1, 4], dtype=self.dtype)
y = paddle.fluid.data('y', [1, 4], dtype=self.dtype)
w = paddle.fluid.data('w', [1], dtype=self.dtype)
out = paddle.lerp(x, y, w)
exe = paddle.static.Executor(place)
res = exe.run(feed={
'x': self.x.reshape([1, 4]),
'y': self.y.reshape([1, 4]),
'w': self.w
})
for r in res:
self.assertEqual(np.allclose(self.res_ref, r), True)
for place in self.place:
run(place)
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
w = paddle.to_tensor(np.full(4, 0.75).astype(self.dtype))
out = paddle.lerp(x, y, w)
self.assertEqual(np.allclose(self.res_ref, out.numpy()), True)
paddle.enable_static()
for place in self.place:
run(place)
def test_inplace_api(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
x.lerp_(y, 0.75)
self.assertEqual(np.allclose(self.res_ref, x.numpy()), True)
paddle.enable_static()
for place in self.place:
run(place)
def test_inplace_api_exception(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
w = paddle.to_tensor([0.75, 0.75], dtype=self.dtype)
with self.assertRaises(ValueError):
x.lerp_(y, w)
paddle.enable_static()
for place in self.place:
run(place)
def test_x_broadcast_y(self):
paddle.disable_static()
x = np.arange(1., 21.).astype(self.dtype).reshape([2, 2, 5])
y = np.full(30, 10.).astype(self.dtype).reshape([3, 2, 1, 5])
out = paddle.lerp(paddle.to_tensor(x), paddle.to_tensor(y), 0.5)
res_ref = x + 0.5 * (y - x)
self.assertEqual(np.allclose(res_ref, out.numpy()), True)
paddle.enable_static()
def test_x_y_broadcast_w(self):
paddle.disable_static()
x = np.arange(11., 21.).astype(self.dtype).reshape([2, 5])
y = np.full(20, 7.5).astype(self.dtype).reshape([2, 2, 5])
w = np.full(40, 0.225).astype(self.dtype).reshape([2, 2, 2, 5])
out = paddle.lerp(
paddle.to_tensor(x), paddle.to_tensor(y), paddle.to_tensor(w))
res_ref = x + w * (y - x)
self.assertEqual(np.allclose(res_ref, out.numpy()), True)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -190,6 +190,8 @@ from .math import digamma # noqa: F401
from .math import neg # noqa: F401
from .math import lgamma # noqa: F401
from .math import diagonal # noqa: F401
from .math import lerp # noqa: F401
from .math import lerp_ # noqa: F401
from .math import rad2deg # noqa: F401
from .math import deg2rad # noqa: F401
from .math import diff # noqa: F401
......@@ -408,6 +410,8 @@ tensor_method_func = [ #noqa
'solve',
'triangular_solve',
'diff',
'lerp',
'lerp_',
'angle',
]
......
......@@ -2614,6 +2614,68 @@ def atan2(x, y, name=None):
type='atan2', inputs=inputs, outputs={'Out': out})
return out
def lerp(x, y, weight, name=None):
r"""
Does a linear interpolation between x and y based on weight.
Equation:
.. math::
lerp(x, y, weight) = x + weight * (y - x).
Args:
x (Tensor): An N-D Tensor, the data type is float32, float64.
y (Tensor): An N-D Tensor, the data type is float32, float64.
weight (float|Tensor): the weight for the interpolation formula.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): An N-D Tensor, the shape and data type is the same with input.
Example:
.. code-block:: python
import paddle
x = paddle.arange(1., 5., dtype='float32')
y = paddle.empty([4], dtype='float32')
y.fill_(10.)
out = paddle.lerp(start, end, 0.5)
# out: [5.5., 6., 6.5, 7.]
"""
if in_dygraph_mode():
check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
if isinstance(weight, float):
weight = paddle.to_tensor(weight, dtype=x.dtype)
return _C_ops.lerp(x, y, weight)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp')
helper = LayerHelper('lerp', **locals())
inputs = {'X': x, 'Y': y, 'Weight': weight}
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='lerp', inputs=inputs, outputs={'Out': out})
return out
@inplace_apis_in_dygraph_only
def lerp_(x, y, weight, name=None):
r"""
Inplace version of ``lerp`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_lerp`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp')
if isinstance(weight, float):
weight = paddle.to_tensor([weight], dtype=x.dtype)
elif isinstance(weight, (paddle.Tensor, Variable)):
out_shape = broadcast_shape(out_shape, weight.shape)
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
return _C_ops.lerp_(x, y, weight)
def rad2deg(x, name=None):
"""
Convert each of the elements of input x from angles in radians to degrees.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册