Commit c7c16a71 authored by liuruilong

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -32,21 +32,15 @@ struct OpInfo {
}
};
template <typename Dtype>
class OpInfoMap;
template <typename Dtype>
static OpInfoMap<Dtype> *g_op_info_map = nullptr;
template <typename Dtype>
class OpInfoMap {
public:
static OpInfoMap &Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug";
if (g_op_info_map<Dtype> == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap();
static OpInfoMap<Dtype> *Instance() {
static OpInfoMap<Dtype> *s_instance = nullptr;
if (s_instance == nullptr) {
s_instance = new OpInfoMap();
}
return *g_op_info_map<Dtype>;
return s_instance;
}
bool Has(const std::string &op_type) const {
......
......@@ -35,7 +35,7 @@ class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) {
if (OpInfoMap<Dtype>::Instance()->Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once.";
return;
......@@ -47,7 +47,7 @@ struct OperatorRegistrar : public Registrar {
}
OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info);
OpInfoMap<Dtype>::Instance()->Insert(op_type, info);
}
};
......@@ -95,10 +95,10 @@ class OpRegistry {
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " << outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " << attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1)
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance().map().size();
<< " OpInfoMap size: " << OpInfoMap<Dtype>::Instance()->map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " << type << " "
<< OpInfoMap<Dtype>::Instance().Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type);
<< OpInfoMap<Dtype>::Instance()->Has(type);
auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
}
......
......@@ -132,13 +132,6 @@ class Tensor {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
inline void *mutable_data() {
// PADDLE_ENFORCE(this->holder_ != nullptr,
// "Cannot invoke mutable data if current hold
// nothing.");
return mutable_data(holder_->type());
}
/**
* @brief Return a pointer to mutable memory block.
*
......
......@@ -35,14 +35,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
LOG(kLOG_DEBUG) << param;
const Tensor *input = param.Input();
// The filter will be reshaped during the computation,
// so a copy is made here (plain assignment) to avoid
// modifying the variable held in the Scope.
Tensor filter = *param.Filter();
Tensor *output = param.Output();
// output->mutable_data<T>(context.GetPlace());
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
......@@ -53,17 +48,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
const int batch_size = static_cast<int>(input->dims()[0]);
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h,
// k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h,
// o_w}
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h,
// k_w, o_d,
// o_h, o_w}
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
......@@ -73,24 +60,19 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w,
// o_d *
// o_h * o_w)
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
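// Assumption for clarity: is_expand is false only for a 1x1 filter with unit
// stride, zero padding and unit dilation; in that case the im2col/vol2col
// expansion can be skipped and the input slice feeds the GEMM directly.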
Tensor col;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
DLOG << " col_shape = " << col_shape;
DLOG << " col_matrix_shape = " << col_matrix_shape;
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
......@@ -98,6 +80,7 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
DLOG << " filter.deims() = " << filter.dims();
framework::DDim output_matrix_shape = {
output->dims()[1],
......@@ -110,8 +93,6 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
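// im2col covers the 2-D filter case (k_h, k_w); vol2col covers the 3-D case
// (k_d, k_h, k_w). Which one runs is chosen per batch below from data_dim.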
// auto& dev_ctx = context.template
// device_context<DeviceContext>();
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -137,6 +118,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
DLOG << " out_slice " << out_slice.dims();
DLOG << " filter_slice " << filter_slice.dims();
DLOG << " col_matrix " << col_matrix.dims();
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../sigmoid_kernel.h"
#if __ARM_NEON
#include "../../math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::Tensor;
void sigmoid(const Tensor *X, Tensor *Y) {
#if __ARM_NEON
DLOG << "step1";
const float *input = X->data<float>();
DLOG << "step11";
float *output = Y->mutable_data<float>();
DLOG << "step2";
const DDim &dDim = X->dims();
DLOG << "step3";
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DLOG << "step4";
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
DLOG << "step5";
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
DLOG << "step6";
DLOG << "outsize=" << out_size;
DLOG << "innersize=" << inner_size;
// The OpenMP pragma must immediately precede the loop it parallelizes,
// so the debug logs are emitted before it.
#pragma omp parallel for
for (int i = 0; i < out_size; ++i) {
const float *input_outer_ptr = input + i * inner_size;
float *output_outer_ptr = output + i * inner_size;
int nn = inner_size >> 2;
int remain = inner_size - (nn << 2);
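// Process four floats per iteration with NEON; the scalar loop below handles
// the remaining (inner_size % 4) elements.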
float32x4_t _one = vdupq_n_f32(1.f);
for (; nn > 0; nn--) {
float32x4_t data = vld1q_f32(input_outer_ptr);
data = vnegq_f32(data);
data = exp_ps(data);
data = vaddq_f32(data, _one);
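// vrecpeq_f32 only gives an approximate reciprocal estimate; the vrecpsq_f32
// multiply below applies one Newton-Raphson refinement step before storing
// 1 / (1 + exp(-x)).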
float32x4_t out_data = vrecpeq_f32(data);
out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data);
vst1q_f32(output_outer_ptr, out_data);
input_outer_ptr += 4;
output_outer_ptr += 4;
}
for (; remain > 0; remain--) {
*output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr));
output_outer_ptr++;
input_outer_ptr++;
}
}
#endif
}
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
sigmoid(in_x, out);
}
template class SigmoidKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
void sigmoid(const Tensor* X, Tensor* Y);
template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public:
void Compute(const SigmoidParam& param) const override;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -21,6 +21,8 @@ namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
void sigmoid(const Tensor *X, Tensor *Y);
template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public:
......
......@@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/softmax.h"
#include "common/types.h"
#if __ARM_NEON
#include <math.h>
#include <algorithm>
#include "operators/math/math_func_neon.h"
#endif
......@@ -108,7 +108,7 @@ class SoftmaxFuntor<CPU, T> {
// sum exp
sum(exp_sub_max, sumptr, inner_size, out_size);
// div
auto *out_ptr = static_cast<float *>(Y->mutable_data());
auto *out_ptr = Y->mutable_data<float>();
for (int l = 0; l < out_size; ++l) {
const float *input_outer_ptr = exp_sub_max + l * inner_size;
float *output_outer_ptr = out_ptr + l * inner_size;
......
......@@ -542,6 +542,22 @@ class SoftmaxParam : public OpParam {
Tensor *input_x_;
Tensor *out_;
};
class SigmoidParam : public OpParam {
public:
SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
private:
Tensor *input_x_;
Tensor *out_;
};
class MultiClassNMSParam : public OpParam {
public:
MultiClassNMSParam(const VariableNameMap &inputs,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/sigmoid_op.h"
namespace paddle_mobile {
namespace operators {
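// Sigmoid is elementwise, so InferShape simply copies the input shape to the output.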
template <typename DeviceType, typename T>
void SigmoidOp<DeviceType, T>::InferShape() const {
param_.Out()->Resize(param_.InputX()->dims());
}
template class SigmoidOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
USE_OP(sigmoid);
REGISTER_OPERATOR(sigmoid, ops::SigmoidOp);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <framework/operator.h>
#include <operators/op_param.h>
#include <string>
#include "operators/kernel/sigmoid_kernel.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SigmoidOp : public framework::OperatorWithKernel<DeviceType> {
public:
SigmoidOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
void Run() const {
operators::SigmoidKernel<DeviceType, T> kernel;
kernel.Compute(param_);
this->ClearVariables({"X"});
}
private:
SigmoidParam param_;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -82,3 +82,7 @@ target_link_libraries(test-enforce paddle-mobile)
# gen test
ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
target_link_libraries(test-sigmoid paddle-mobile)
......@@ -17,11 +17,14 @@ limitations under the License. */
#include <string>
#include <vector>
#include "./io.h"
#include "common/log.h"
#include "io.h"
#include "framework/op_registry.h"
#include "operators/conv_op.h"
#include "operators/pool_op.h"
#include "operators/relu_op.h"
#include "operators/reshape_op.h"
#include "operators/sigmoid_op.h"
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"
......@@ -57,9 +60,13 @@ class Executor4Test : public Executor<DeviceType> {
for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) {
std::shared_ptr<OpType> op_ptr = std::make_shared<OpType>(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
this->program_.scope);
/// test the first matching op in the program
std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
op_ptr = paddle_mobile::framework::OpRegistry<
paddle_mobile::CPU>::CreateOp(op->Type(), op->GetInputs(),
op->GetOutputs(),
op->GetAttrMap(),
this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
break;
}
......
......@@ -29,6 +29,9 @@ int main() {
paddle_mobile::framework::Tensor input;
GetInput<float>(g_test_image_1x3x224x224, &input, {1, 3, 224, 224});
// // use SetupTensor if there is no local input image.
// SetupTensor<float>(&input, {1, 3, 224, 224}, static_cast<float>(0),
// static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112});
auto output = executor.predict(input, "data", "conv2d_0.tmp_0", out_ddim);
......
......@@ -111,7 +111,7 @@ int main() {
DLOG << "begin to run ElementAddOp Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program =
loader.Load(std::string("../../test/models/"
loader.Load(std::string("../models/"
"image_classification_resnet.inference.model"));
/// input x (1,3,224,224)
......
......@@ -12,108 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../executor_for_test.h"
#include "../test_include.h"
#include "operators/relu_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestReluOp {
public:
explicit TestReluOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (auto op : ops) {
if (op->Type() == "relu" &&
op->Input("X")[0] == "batch_norm_34.tmp_2") {
DLOG << "in";
std::shared_ptr<operators::ReluOp<Dtype, float>> test_op =
std::make_shared<operators::ReluOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(test_op);
}
}
}
}
std::shared_ptr<Tensor> predict(const Tensor &t1) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("batch_norm_34.tmp_2");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *output = scope->Var("batch_norm_34.tmp_3");
auto *output_tensor = output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({1, 2, 3, 4});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t1, 0);
return out_tensor;
// return outvars_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict(const Tensor &t1, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestReluOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Relu Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
// ../models/image_classification_resnet.inference.model
auto program = loader.Load(g_mobilenet_ssd);
/// input x (1,3,300,300)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 2, 3, 4}, static_cast<float>(-1),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
paddle_mobile::framework::TestReluOp<paddle_mobile::CPU> testReluOp(program);
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>>
executor(program, "relu");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 2, 3, 4}, static_cast<float>(-1),
static_cast<float>(1));
auto output = testReluOp.predict(inputx1);
auto *output_ptr = output->data<float>();
auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4});
auto output = executor.predict(input, "batch_norm_0.tmp_2",
"batch_norm_0.tmp_3", out_ddim);
for (int i = 0; i < output->numel(); i++) {
DLOG << output_ptr[i];
auto output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../test_helper.h"
#include "./io.h"
int main() {
paddle_mobile::framework::Tensor input;
paddle_mobile::framework::Tensor output;
DLOG << 1;
SetupTensor<float>(&input, {1, 4, 60, 60}, static_cast<float>(0),
static_cast<float>(1));
DLOG << 2;
auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60});
output.Resize(out_ddim);
DLOG << 3;
paddle_mobile::operators::sigmoid(&input, &output);
DLOG << 4;
auto *output_ptr = output.data<float>();
for (int j = 0; j < output.numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
DLOG << 5;
return 0;
}
......@@ -23,7 +23,7 @@ limitations under the License. */
static const std::string g_googlenet = "../models/googlenet";
static const std::string g_mobilenet = "../models/mobilenet";
static const std::string g_mobilenet_ssd = "../models/mobilenet";
static const std::string g_mobilenet_ssd = "../models/mobilenet+ssd";
static const std::string g_squeezenet = "../models/squeezenet";
static const std::string g_resnet =
"../models/image_classification_resnet.inference.model";
......