未验证 提交 0f295c87 编写于 作者: Y yiicy 提交者: GitHub

[OPENCL] add pad2d image kernel and ut, test=develop (#3143)

add pad2d image kernel and ut
上级 08a3ed12
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void pad2d_constant(
__read_only image2d_t input, __write_only image2d_t output,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_h0, const int pad_h1,
const int pad_w0, const int pad_w1,
const float pad_value) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;
int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x = out_w - pad_w0;
int y = out_h - pad_h0;
if (x < 0 || y < 0 || x >= in_width || y >= in_height) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, (CL_DTYPE4)(pad_value));
} else {
int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y);
CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel);
}
}
__kernel void pad2d_reflect(
__read_only image2d_t input, __write_only image2d_t output,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_h0, const int pad_h1,
const int pad_w0, const int pad_w1,
const float pad_value) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;
int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x = out_w - pad_w0;
int y = out_h - pad_h0;
x = abs(x);
y = abs(y);
x = x < in_width ? x : 2 * in_width - 2 - x;
y = y < in_height ? y : 2 * in_height - 2 - y;
int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y);
CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel);
}
__kernel void pad2d_edge(
__read_only image2d_t input, __write_only image2d_t output,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_h0, const int pad_h1,
const int pad_w0, const int pad_w1,
const float pad_value) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh / out_height;
const int out_h = out_nh % out_height;
int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int x = out_w - pad_w0;
int y = out_h - pad_h0;
x = x > 0 ? x : 0;
x = x < in_width ? x : in_width - 1;
y = y > 0 ? y : 0;
y = y < in_height ? y : in_height - 1;
int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y);
CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel);
}
......@@ -32,6 +32,7 @@ add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_comput
add_kernel(slice_opencl OPENCL basic SRCS slice_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(dropout_opencl OPENCL basic SRCS dropout_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kernel_deps})
# extra
# wait to add ...
......@@ -92,7 +93,10 @@ lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_te
DEPS instance_norm_opencl op_registry program context)
lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc
DEPS dropout_opencl op_registry program context)
DEPS dropout_opencl op_registry program context)
lite_cc_test(test_pad2d_image_opencl SRCS pad2d_image_compute_test.cc
DEPS pad2d_opencl layout_opencl op_registry program context)
######################
# buffer kernel #
######################
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class Pad2dCompute : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::Pad2dParam;
std::string doc() const override {
return "Pad2d using cl::Image2D(ImageDefault/RGBA), kFP16";
}
void PrepareForRun() override {
pad2d_param_ = param_.get_mutable<param_t>();
if (pad2d_param_->mode == "constant") {
kernel_func_name_ = "pad2d_constant";
} else if (pad2d_param_->mode == "reflect") {
kernel_func_name_ = "pad2d_reflect";
} else if (pad2d_param_->mode == "edge") {
kernel_func_name_ = "pad2d_edge";
} else {
LOG(FATAL) << "Unknown mode type";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/pad2d_kernel.cl", build_options_);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
}
void Run() override {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* x = pad2d_param_->X;
auto* out = pad2d_param_->Out;
auto out_dims = out->dims();
auto in_dims = x->dims();
int in_h = in_dims[2];
int in_w = in_dims[3];
int out_h = out_dims[2];
int out_w = out_dims[3];
VLOG(4) << "x->target():" << TargetToStr(x->target());
VLOG(4) << "out->target():" << TargetToStr(out->target());
VLOG(4) << "x->dims():" << in_dims;
VLOG(4) << "out->dims():" << out_dims;
auto out_image_shape = InitImageDimInfoWith(out_dims);
auto* x_img = x->data<half_t, cl::Image2D>();
auto* out_img = out->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
VLOG(4) << "out_image_shape[w,h]: " << out_image_shape["width"] << " "
<< out_image_shape["height"];
VLOG(4) << "in_h: " << in_h << ", in_w: " << in_w;
VLOG(4) << "out_h: " << out_h << ", out_w: " << out_w;
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0;
auto default_work_size =
DefaultWorkSize(out_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
VLOG(4) << "default_work_size: " << default_work_size[0] << ", "
<< default_work_size[1] << ", " << default_work_size[2];
int pad_h0 = pad2d_param_->paddings[0];
int pad_h1 = pad2d_param_->paddings[1];
int pad_w0 = pad2d_param_->paddings[2];
int pad_w1 = pad2d_param_->paddings[3];
float pad_value = pad2d_param_->pad_value;
cl_int status = kernel.setArg(arg_idx++, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_h);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_h);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, out_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_h0);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_h1);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_w0);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_w1);
CL_CHECK_FATAL(status);
status = kernel.setArg(arg_idx++, pad_value);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
<< global_work_size[1] << " " << global_work_size[2];
}
protected:
param_t* pad2d_param_{nullptr};
std::string kernel_func_name_{};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
namespace ocl = paddle::lite::kernels::opencl;
REGISTER_LITE_KERNEL(
pad2d, kOpenCL, kFP16, kImageDefault, ocl::Pad2dCompute, ImageDefault)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/image_helper.h"
namespace paddle {
namespace lite {
void pad2d_ref(const float *x_data,
Tensor *y,
std::string mode,
int pad_h0,
int pad_h1,
int pad_w0,
int pad_w1,
float pad_value) {
auto *out_data = y->mutable_data<float>();
auto output_dims = y->dims();
int n = output_dims[0];
int c = output_dims[1];
int h = output_dims[2];
int w = output_dims[3];
int pad_mode;
if (mode == "constant") {
pad_mode = 0;
} else if (mode == "reflect") {
pad_mode = 2;
} else if (mode == "edge") {
pad_mode = 1;
} else {
LOG(FATAL) << "Unknown mode type";
}
int in_w = w - pad_w0 - pad_w1;
int in_h = h - pad_h0 - pad_h1;
int spatial_size_out = w * h;
int spatial_size_in = in_w * in_h;
#pragma omp parallel for
for (int i = 0; i < n * c; ++i) {
const float *din_batch = x_data + i * spatial_size_in;
float *dout_batch = out_data + i * spatial_size_out;
int in_y = 0;
int in_x = 0;
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
switch (pad_mode) {
case 0:
in_y = y - pad_h0;
in_x = x - pad_w0;
dout_batch[y * w + x] =
(in_x >= 0 && in_x < in_w) && (in_y >= 0 && in_y < in_h)
? din_batch[in_y * in_w + in_x]
: pad_value;
break;
case 1:
in_x = std::min(std::max(pad_w0, x), in_w + pad_w0 - 1) - pad_w0;
in_y = std::min(std::max(pad_h0, y), in_h + pad_h0 - 1) - pad_h0;
dout_batch[y * w + x] = din_batch[in_y * in_w + in_x];
break;
case 2:
in_y = y - pad_h0;
in_x = x - pad_w0;
in_y = std::max(in_y, -in_y);
in_y = std::min(in_y, 2 * in_h - in_y - 2);
in_x = std::max(in_x, -in_x);
in_x = std::min(in_x, 2 * in_w - in_x - 2);
dout_batch[y * w + x] = din_batch[in_y * in_w + in_x];
break;
default:
LOG(ERROR) << "ERROR: unknown pad mode:" << pad_mode;
}
}
}
}
}
#define LOOP_TEST
// #define PRINT_RESULT
TEST(pad2d_image2d, compute) {
LOG(INFO) << "main steps of test: host -> layout(buf2img) -> "
"pad2d(img) -> "
"layout(img2buf) "
"-> host";
#ifdef LOOP_TEST
for (int n : {1, 3}) {
for (auto c : {1, 3}) {
for (int h : {12, 112}) {
for (int w : {12, 112}) {
for (int pad_h0 : {0, 1, 2}) {
for (int pad_h1 : {0, 1, 2}) {
for (int pad_w0 : {0, 1, 2}) {
for (int pad_w1 : {0, 1, 2}) {
for (float pad_value : {10.f}) {
for (std::string pad_mode :
{"constant", "reflect", "edge"}) {
#else
const int n = 1;
const int c = 3;
const int h = 12;
const int w = 112;
const int pad_h0 = 1;
const int pad_h1 = 2;
const int pad_w0 = 1;
const int pad_w1 = 2;
const float pad_value = 10.f;
std::string pad_mode = "reflect";
#endif // LOOP_TEST
LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " "
<< c << " " << h << " " << w;
LOG(INFO) << "======== pad_h0: " << pad_h0
<< ", pad_h1: " << pad_h1
<< ", pad_w0: " << pad_w0
<< ", pad_w1: " << pad_w1
<< ", pad_value: " << pad_value
<< ", pad_mode: " << pad_mode;
// set layout kernels
auto buf_to_img_kernels = KernelRegistry::Global().Create(
"layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault));
auto img_to_buf_kernels =
KernelRegistry::Global().Create("layout",
TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kNCHW));
auto pad2d_img_kernels = KernelRegistry::Global().Create(
"pad2d",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(buf_to_img_kernels.empty());
ASSERT_FALSE(pad2d_img_kernels.empty());
auto buf_to_img_kernel =
std::move(buf_to_img_kernels.front());
auto img_to_buf_kernel =
std::move(img_to_buf_kernels.front());
auto pad2d_img_kernel =
std::move(pad2d_img_kernels.front());
LOG(INFO) << "get 1st kernel: "
<< buf_to_img_kernel->doc();
LOG(INFO) << "get 2nd kernel: "
<< img_to_buf_kernel->doc();
LOG(INFO) << "get 3rd kernel: "
<< pad2d_img_kernel->doc();
// set tensors about op param
LOG(INFO) << "set tensors about op param";
// layout(buf->img): x -> pad2d_in
// pad2d(img): pad2d_in -> pad2d_out
// layout(img->buf): pad2d_out -> y
lite::Tensor x, y, pad2d_in, pad2d_out, y_ref;
operators::LayoutParam BufferToImageParam;
operators::LayoutParam ImageToBufferParam;
BufferToImageParam.x = &x;
BufferToImageParam.y = &pad2d_in;
ImageToBufferParam.x = &pad2d_out;
ImageToBufferParam.y = &y;
operators::Pad2dParam Pad2dParam;
Pad2dParam.X = &pad2d_in;
Pad2dParam.Out = &pad2d_out;
Pad2dParam.paddings = {pad_h0, pad_h1, pad_w0, pad_w1};
Pad2dParam.pad_value = pad_value;
Pad2dParam.mode = pad_mode;
int64_t out_h = h + pad_h0 + pad_h1;
int64_t out_w = w + pad_w0 + pad_w1;
const DDim x_dim =
DDim(std::vector<DDim::value_type>{n, c, h, w});
const DDim y_dim = DDim(
std::vector<DDim::value_type>{n, c, out_h, out_w});
x.Resize(x_dim);
y.Resize(y_dim);
pad2d_in.Resize(x_dim);
pad2d_out.Resize(y_dim);
y_ref.Resize(y_dim);
auto pad2d_image2d_shape =
paddle::lite::kernels::opencl::InitImageDimInfoWith(
x_dim);
// initialize tensors
LOG(INFO) << "initialize tensors";
auto *x_data =
x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data =
y.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
auto *y_data_ref =
y_ref.mutable_data<float>(TARGET(kARM));
auto *mapped_x =
static_cast<float *>(TargetWrapperCL::Map(
x_data, 0, sizeof(float) * x_dim.production()));
auto *mapped_y =
static_cast<float *>(TargetWrapperCL::Map(
y_data, 0, sizeof(float) * y_dim.production()));
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-1, 1);
for (int i = 0; i < x_dim.production(); ++i) {
mapped_x[i] = dist(engine);
}
auto *pad2d_in_data =
pad2d_in.mutable_data<half_t, cl::Image2D>(
pad2d_image2d_shape["width"],
pad2d_image2d_shape["height"]);
auto *pad2d_out_data =
pad2d_out.mutable_data<half_t, cl::Image2D>(y_dim[3],
y_dim[2]);
// set context and kernel args
LOG(INFO) << "set context and kernel args";
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
buf_to_img_kernel->SetParam(BufferToImageParam);
std::unique_ptr<KernelContext> buf_to_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(buf_to_img_context->As<OpenCLContext>()));
buf_to_img_kernel->SetContext(
std::move(buf_to_img_context));
img_to_buf_kernel->SetParam(ImageToBufferParam);
std::unique_ptr<KernelContext> img_to_buf_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(img_to_buf_context->As<OpenCLContext>()));
img_to_buf_kernel->SetContext(
std::move(img_to_buf_context));
pad2d_img_kernel->SetParam(Pad2dParam);
std::unique_ptr<KernelContext> pad2d_img_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(pad2d_img_context->As<OpenCLContext>()));
pad2d_img_kernel->SetContext(
std::move(pad2d_img_context));
// run kernels
LOG(INFO) << "run kernel: buf_to_img_kernel";
buf_to_img_kernel->Launch();
LOG(INFO) << "run kernel: pad2d_img_kernel";
pad2d_img_kernel->Launch();
LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch();
// wait for opencl
auto *wait_list =
context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr =
ImageToBufferParam.y->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL)
<< "Could not find the sync event for the target "
"cl tensor.";
}
// compute ref cpu
pad2d_ref(mapped_x,
&y_ref,
pad_mode,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
pad_value);
// result
#ifdef PRINT_RESULT
LOG(INFO)
<< "---- print kernel result (input -> output) ----";
for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
std::cout << mapped_x[eidx] << " ";
}
std::cout << std::endl;
for (int eidx = 0; eidx < y_dim.production(); ++eidx) {
std::cout << mapped_y[eidx] << " ";
}
std::cout << std::endl;
for (int eidx = 0; eidx < y_dim.production(); ++eidx) {
std::cout << y_data_ref[eidx] << " ";
}
std::cout << std::endl;
#endif // PRINT_RESULT
// check result: compare kernel output and cpu
// output(y_data_ref)
for (int eidx = 0; eidx < y_dim.production(); eidx++) {
EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-3);
if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-3) {
LOG(FATAL) << "1st diff in this case at eidx[from 0]:"
<< eidx << " / " << y_dim.production()
<< ", y_data_ref[" << eidx
<< "]:" << y_data_ref[eidx]
<< ", mapped_y[" << eidx
<< "]:" << mapped_y[eidx];
break;
}
}
// free
LOG(INFO) << "free: unmap x, y";
TargetWrapperCL::Unmap(x_data, mapped_x);
TargetWrapperCL::Unmap(y_data, mapped_y);
#ifdef LOOP_TEST
} // pad_mode
} // pad_value
} // pad_w1
} // pad_w0
} // pad_h1
} // pad_h0
} // w
} // h
} // c
} // n
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
// pad2d image2d fp32
USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW);
// pad image2d fp16
USE_LITE_KERNEL(pad2d, kOpenCL, kFP16, kImageDefault, ImageDefault);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册