“a615ad46e4f8a232deae9db7558e5277f057e208”上不存在“python/paddle/fluid/tests/unittests/test_dist_mnist_train.py”
提交 41c52d3b 编写于 作者: H hedaoyuan

Modify the argument type of ContextProjectionFunc

上级 68156c88
......@@ -3,6 +3,7 @@ file(GLOB cpp_files . *Op.cpp)
list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)
list(APPEND cpp_files BufferArg.cpp)
if(WITH_GPU)
file(GLOB cu_files . *OpGpu.cu)
......@@ -16,10 +17,13 @@ if(WITH_TESTING)
# TODO:
# file(GLOB test_files . *OpTest.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(CrossMapNormalOpTest)
add_unittest(ContextProjectionOpTest
ContextProjectionOpTest.cpp
../gserver/tests/TestUtil.cpp)
# add_simple_unittest(CrossMapNormalOpTest)
add_simple_unittest(TensorShapeTest)
add_simple_unittest(TensorTypeTest)
add_simple_unittest(BufferArgTest)
# add_unittest(ContextProjectionOpTest
# ContextProjectionOpTest.cpp
# ../gserver/tests/TestUtil.cpp)
endif()
endif()
......
......@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
template <>
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
const CpuMatrix* input_mat,
const CpuMatrix* weight_mat,
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
const CpuMatrix& input_mat,
const CpuMatrix& weight_mat,
const CpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad) {
const int* starts = seq_vec.getData();
const size_t num_sequences = seq_vec.getSize() - 1;
auto w_mat = const_cast<CpuMatrix*>(weight_mat);
auto in_mat = const_cast<CpuMatrix*>(input_mat);
for (size_t i = 0; i < num_sequences; ++i) {
for (size_t j = 0; j < context_length; ++j) {
int begin = starts[i] + context_start + j;
......@@ -39,10 +37,11 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
if (begin < starts[i]) {
int64_t pad_size =
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
if (w_mat) {
MatrixPtr sub = w_mat->subMatrix(j, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
MatrixPtr mat = out_mat.subMatrix(starts[i], pad_size);
if (weight_mat) {
MatrixPtr sub =
const_cast<CpuMatrix&>(weight_mat).subMatrix(j, pad_size);
mat->addAtOffset(*sub, j * input_mat.getWidth());
}
dst_begin = starts[i] + pad_size;
begin = starts[i];
......@@ -50,19 +49,22 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
if (end > starts[i + 1]) {
int64_t pad_size =
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
if (w_mat) {
MatrixPtr sub = w_mat->subMatrix(
begin_pad + context_start + j - pad_size, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
MatrixPtr mat = out_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
if (weight_mat) {
MatrixPtr sub =
const_cast<CpuMatrix&>(weight_mat)
.subMatrix(begin_pad + context_start + j - pad_size,
pad_size);
mat->addAtOffset(*sub, j * input_mat.getWidth());
}
dst_end = starts[i + 1] - pad_size;
end = starts[i + 1];
}
if (end <= begin) continue;
MatrixPtr src = in_mat->subMatrix(begin, end - begin);
MatrixPtr dst = out_mat->subMatrix(dst_begin, dst_end - dst_begin);
dst->addAtOffset(*src, j * in_mat->getWidth());
MatrixPtr src =
const_cast<CpuMatrix&>(input_mat).subMatrix(begin, end - begin);
MatrixPtr dst = out_mat.subMatrix(dst_begin, dst_end - dst_begin);
dst->addAtOffset(*src, j * input_mat.getWidth());
}
}
}
......@@ -82,40 +84,34 @@ public:
begin_pad_ = config.get<size_t>("begin_pad");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
void calc(const BufferArgs& inputs,
const BufferArgs& outputs,
const BufferArgs& inouts) override {
CHECK_EQ(3, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
CHECK_EQ(outputs[0].dims_.size(), 2);
CHECK_EQ(inputs[0].dims_.size(), 2);
CHECK_EQ(inputs[1].dims_.size(), 2);
CHECK_EQ(inputs[2].dims_.size(), 1);
CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
CHECK_EQ(outputs[0].shape().ndims(), 2);
CHECK_EQ(inputs[0].shape().ndims(), 2);
CHECK_EQ(inputs[1].shape().ndims(), 2);
CHECK_EQ(inputs[2].shape().ndims(), 1);
/// dim of output = dim of input * context_length
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
/// dim of input == dim of weight
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
/// input and output has the same batch_size
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
auto out_mat = std::make_shared<typename MatrixT<Device>::type>(
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
const auto in_mat = std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
const auto w_mat =
!inputs[1].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
ContextProjectionForward<Device>(out_mat.get(),
in_mat.get(),
w_mat.get(),
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
auto out_mat = outputs[0].matrix<Device>();
auto in_mat = inputs[0].matrix<Device>();
auto w_mat = !inputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[1].matrix<Device>();
auto seq_vec = inputs[2].vector<int, Device>();
ContextProjectionForward<Device>(out_mat,
in_mat,
w_mat,
seq_vec,
context_length_,
context_start_,
......@@ -129,18 +125,17 @@ private:
};
template <>
void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
CpuMatrix* in_grad_mat,
CpuMatrix* w_grad_mat,
void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
CpuMatrix& in_grad_mat,
CpuMatrix& w_grad_mat,
const CpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad) {
CHECK(out_grad_mat);
size_t input_dim = in_grad_mat ? in_grad_mat->getWidth()
: w_grad_mat ? w_grad_mat->getWidth() : 0;
size_t input_dim = in_grad_mat ? in_grad_mat.getWidth()
: w_grad_mat ? w_grad_mat.getWidth() : 0;
const int* starts = seq_vec.getData();
size_t num_sequences = seq_vec.getSize() - 1;
for (size_t i = 0; i < num_sequences; ++i) {
......@@ -153,8 +148,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
int64_t pad_size =
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
if (is_padding && w_grad_mat) {
MatrixPtr mat = out_grad_mat->subMatrix(starts[i], pad_size);
MatrixPtr sub = w_grad_mat->subMatrix(j, pad_size);
MatrixPtr mat = out_grad_mat.subMatrix(starts[i], pad_size);
MatrixPtr sub = w_grad_mat.subMatrix(j, pad_size);
sub->addAtOffset(*mat, j * input_dim);
}
dst_begin = starts[i] + pad_size;
......@@ -165,8 +160,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
if (is_padding && w_grad_mat) {
MatrixPtr mat =
out_grad_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
MatrixPtr sub = w_grad_mat->subMatrix(
out_grad_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
MatrixPtr sub = w_grad_mat.subMatrix(
begin_pad + context_start + j - pad_size, pad_size);
sub->addAtOffset(*mat, j * input_dim);
}
......@@ -175,8 +170,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
}
if (end <= begin) continue;
if (!in_grad_mat) continue;
MatrixPtr src = in_grad_mat->subMatrix(begin, end - begin);
MatrixPtr dst = out_grad_mat->subMatrix(dst_begin, dst_end - dst_begin);
MatrixPtr src = in_grad_mat.subMatrix(begin, end - begin);
MatrixPtr dst = out_grad_mat.subMatrix(dst_begin, dst_end - dst_begin);
src->addAtOffset(*dst, j * input_dim);
}
}
......@@ -199,44 +194,37 @@ public:
total_pad_ = config.get<size_t>("total_pad");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
void calc(const BufferArgs& inputs,
const BufferArgs& outputs,
const BufferArgs& inouts) override {
CHECK_EQ(3, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK(outputs[0].getData() && inputs[2].getData());
CHECK_EQ(outputs[0].dims_.size(), 2);
CHECK_EQ(inputs[0].dims_.size(), 2);
CHECK_EQ(inputs[1].dims_.size(), 2);
CHECK_EQ(inputs[2].dims_.size(), 1);
CHECK(outputs[0].data() && inputs[2].data());
CHECK_EQ(outputs[0].shape().ndims(), 2);
CHECK_EQ(inputs[0].shape().ndims(), 2);
CHECK_EQ(inputs[1].shape().ndims(), 2);
CHECK_EQ(inputs[2].shape().ndims(), 1);
/// dim of input == dim of weight
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
/// input and output has the same batch_size
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
/// dim of output = dim of input * context_length
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
auto out_grad_mat = outputs[0].matrix<Device>();
auto in_grad_mat =
!inputs[0].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
auto w_grad_mat =
!inputs[1].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
ContextProjectionBackward<Device>(out_grad_mat.get(),
in_grad_mat ? in_grad_mat.get() : nullptr,
w_grad_mat ? w_grad_mat.get() : nullptr,
!inputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[0].matrix<Device>();
auto w_grad_mat = !inputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[1].matrix<Device>();
auto seq_vec = inputs[2].vector<int, Device>();
ContextProjectionBackward<Device>(out_grad_mat,
in_grad_mat,
w_grad_mat,
seq_vec,
context_length_,
context_start_,
......@@ -253,6 +241,7 @@ private:
size_t total_pad_;
};
#if 0
/**
* \param inputs[0] input grad.
* \param inputs[1] input sequence.
......@@ -272,6 +261,7 @@ public:
CHECK_EQ(2, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
CHECK_EQ(outputs[0].dims_.size(), 2);
CHECK_EQ(inputs[0].dims_.size(), 2);
......@@ -349,6 +339,7 @@ private:
size_t begin_pad_;
size_t total_pad_;
};
#endif
REGISTER_TYPED_FUNC(ContextProjectionForward,
CPU,
......@@ -363,6 +354,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC(ContextProjectionBackward,
GPU,
ContextProjectionBackwardFunc);
#if 0
REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
GPU,
ContextProjectionBackwardDataFunc);
......@@ -370,4 +362,5 @@ REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
GPU,
ContextProjectionBackwardWeightFunc);
#endif
#endif
} // namespace paddle
......@@ -31,14 +31,15 @@ namespace paddle {
* \param[in] is_padding whether padding 0 or not.
*
*/
template <DeviceType Device>
void ContextProjectionForward(typename MatrixT<Device>::type* output,
const typename MatrixT<Device>::type* input,
const typename MatrixT<Device>::type* weight,
const typename SequenceT<Device>::type& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
template <DeviceType DType>
void ContextProjectionForward(
typename Tensor<real, DType>::Matrix& output,
const typename Tensor<real, DType>::Matrix& input,
const typename Tensor<real, DType>::Matrix& weight,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
/**
* \brief Context Projection Backward.
......@@ -53,30 +54,31 @@ void ContextProjectionForward(typename MatrixT<Device>::type* output,
* \param[in] is_padding whether padding 0 or not.
*
*/
template <DeviceType Device>
void ContextProjectionBackward(typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType DType>
void ContextProjectionBackward(
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardData(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
const typename SequenceT<Device>::type& sequence,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardWeight(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,
......
......@@ -120,20 +120,19 @@ void hl_context_projection_forward(const real* input,
}
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix* output,
const GpuMatrix* input,
const GpuMatrix* weight,
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
const GpuMatrix& input,
const GpuMatrix& weight,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad) {
CHECK(input && output);
hl_context_projection_forward(input->getData(),
hl_context_projection_forward(input.getData(),
sequence.getData(),
weight ? weight->getData() : nullptr,
output->getData(),
weight ? weight.getData() : nullptr,
output.getData(),
sequence.getSize() - 1,
input->getWidth(),
input.getWidth(),
context_length,
context_start,
begin_pad);
......@@ -217,17 +216,16 @@ void hl_context_projection_backward_data(real* out_grad,
}
template <>
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start) {
CHECK(in_grad && out_grad);
hl_context_projection_backward_data(out_grad->getData(),
hl_context_projection_backward_data(out_grad.getData(),
sequence.getData(),
in_grad->getData(),
in_grad.getData(),
sequence.getSize() - 1,
in_grad->getWidth(),
in_grad.getWidth(),
context_length,
context_start);
}
......@@ -348,19 +346,18 @@ void hl_context_projection_backward_weight(real* out_grad,
template <>
void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
GpuMatrix* out_grad,
GpuMatrix* w_grad,
GpuMatrix& out_grad,
GpuMatrix& w_grad,
const GpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,
size_t begin_pad) {
CHECK(out_grad && w_grad);
hl_context_projection_backward_weight(out_grad->getData(),
hl_context_projection_backward_weight(out_grad.getData(),
seq_vec.getData(),
w_grad->getData(),
w_grad.getData(),
seq_vec.getSize() - 1,
w_grad->getWidth(),
w_grad.getWidth(),
total_pad,
context_length,
context_start,
......@@ -368,16 +365,15 @@ void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
}
template <>
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
GpuMatrix* w_grad,
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
GpuMatrix& w_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad) {
CHECK(out_grad);
if (in_grad) {
ContextProjectionBackwardData<DEVICE_TYPE_GPU>(
out_grad,
......
......@@ -44,4 +44,21 @@ TEST(TensorType, Vector) {
EXPECT_EQ(gpuIVector.getSize(), 100);
}
TEST(TensorType, EmptyMatrix) {
CpuMatrix empty(nullptr, 0, 0);
CpuMatrix nonEmpty(10, 10);
EXPECT_EQ(empty.isEmpty(), true);
EXPECT_EQ(nonEmpty.isEmpty(), false);
CHECK(nonEmpty);
auto function = [](const CpuMatrix& matrix) {
if (matrix) {
EXPECT_NE(matrix.getData(), nullptr);
} else {
EXPECT_EQ(matrix.getData(), nullptr);
}
};
function(empty);
function(nonEmpty);
}
} // namespace paddle
......@@ -110,7 +110,7 @@ void ContextProjection::forward() {
size_t input_dim = in_->value->getWidth();
size_t dim = out_->value->getWidth();
CHECK_EQ(dim, input_dim * config_.context_length());
size_t batch_size = in_->value->getHeight();
// size_t batch_size = in_->value->getHeight();
CHECK_EQ(forward_.size(), 1) << "Only one forward function here";
REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
......@@ -119,14 +119,17 @@ void ContextProjection::forward() {
auto w_ptr =
state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
auto start_pos = in_->sequenceStartPositions;
forward_[0]->calc({Tensor(in_->value->getData(), Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->value->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
BufferArgs inouts;
inputs.addArg(*in_->value);
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->value);
forward_[0]->calc(inputs, outputs, inouts);
if (state_ && config_.context_start() < 0) {
CHECK_EQ(1, in_->getNumSequences());
......@@ -160,15 +163,18 @@ void ContextProjection::backward(const UpdateCallback& callback) {
bool is_padding = config_.trainable_padding();
auto start_pos = in_->sequenceStartPositions;
auto w_ptr = is_padding ? weight_->getWGrad() : nullptr;
backward_[0]->calc({Tensor(in_->grad ? in_->grad->getData() : nullptr,
Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->grad->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
BufferArgs inouts;
inputs.addArg(CpuMatrix(
in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim));
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->grad);
backward_[0]->calc(inputs, outputs, inouts);
if (config_.trainable_padding()) {
weight_->getParameterPtr()->incUpdate(callback);
......
......@@ -1091,6 +1091,10 @@ public:
TensorCpuApply<real>(*this, expr);
}
}
bool isEmpty() const { return data_ == nullptr; }
explicit operator bool() const { return !isEmpty(); }
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <memory>
#include <thread>
#include "paddle/utils/Logging.h"
#include "paddle/utils/ThreadLocal.h"
#include <hl_gpu.h>
#include "BaseMatrix.h"
#include "MemoryHandle.h"
#include "Vector.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/common.h"
namespace paddle {
enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 };
/**
* @brief matrix sparse_format .
*
* nnz represents nonzero number in sparse matrix.
*
* SPARSE_CSR: row major matrix. length of row is height_ + 1, each element
* represents row start index in Matrix. length of col and value are nnz.
*
* SPARSE_CSC: col major matrix. length of col is width_ + 1, each element
* represents col start index in Matrix. length of col and value are nnz.
*
* @code
* for example: [0, 1, 0, 2, 0;
* 1, 0, 0, 0, 0;
* 0, 0, 0, 2, 5];
* SPARSE_CSR row [0, 2, 3, 5];
* col [1, 3, 0, 3, 4];
* value [1, 2, 1, 2, 5]
* SPARSE_CSC col [0, 1, 2, 2, 4, 5];
* row [1, 0, 0, 2, 2];
* value [1, 1, 2, 2, 5]
* @endcode
*/
enum SparseFormat { SPARSE_CSR = 0, SPARSE_CSC = 1 };
class Matrix;
class GpuMatrix;
class CpuMatrix;
class CpuSparseMatrix;
class GpuSparseMatrix;
typedef std::shared_ptr<Matrix> MatrixPtr;
typedef std::shared_ptr<GpuMatrix> GpuMatrixPtr;
typedef std::shared_ptr<CpuMatrix> CpuMatrixPtr;
typedef std::shared_ptr<GpuSparseMatrix> GpuSparseMatrixPtr;
typedef std::shared_ptr<CpuSparseMatrix> CpuSparseMatrixPtr;
/**
* Copy or assignemnt constructor will share the data as opposed to making a
* copy of the original data. To make a copy of the orinal data, use copyFrom()
* instead.
*/
class Matrix : public BaseMatrix {
protected:
Matrix(MemoryHandlePtr memHandle,
size_t height,
size_t width,
bool trans,
bool use_gpu);
Matrix(real* data, size_t height, size_t width, bool trans, bool use_gpu);
Matrix(real* data,
size_t height,
size_t width,
size_t stride,
bool trans,
bool use_gpu);
static ThreadLocal<MatrixPtr> tmpMat_;
public:
size_t elementCnt_; // maximal number of elements which can be held in data_
MemoryHandlePtr memoryHandle_;
public:
virtual ~Matrix() {}
static MatrixPtr create(MemoryHandlePtr memHandle,
size_t height,
size_t width,
bool trans = false);
static MatrixPtr create(size_t height,
size_t width,
bool trans = false,
bool useGpu = false);
static MatrixPtr create(real* data,
size_t height,
size_t width,
bool trans = false,
bool useGpu = false);
static MatrixPtr create(real* data,
size_t height,
size_t width,
size_t stride,
bool trans = false,
bool useGpu = false);
static MatrixPtr createSparseMatrix(size_t height,
size_t width,
size_t nnz,
SparseValueType valueType = FLOAT_VALUE,
bool trans = false,
bool useGpu = false);
static MatrixPtr createSparseMatrix(size_t height,
size_t width,
size_t nnz,
SparseValueType valueType = FLOAT_VALUE,
SparseFormat foramt = SPARSE_CSR,
bool trans = false,
bool useGpu = false);
static MatrixPtr createSparseMatrix(real* data,
int* row,
int* col,
size_t height,
size_t width,
size_t nnz, /* used to allocate space */
SparseValueType valueType, /*value type*/
SparseFormat format,
bool trans,
bool useGpu);
static void resizeOrCreateSparseMatrix(
MatrixPtr& matrix,
size_t height,
size_t width,
size_t nnz,
SparseValueType valueType = FLOAT_VALUE,
SparseFormat foramt = SPARSE_CSR,
bool trans = false,
bool useGpu = false);
static void resizeOrCreate(MatrixPtr& a,
size_t height,
size_t width,
bool trans = false,
bool useGpu = false);
/**
* @brief set the data buffer used to hold the matrix data.
*
* caller should make sure that the size of data is at least
* sizeof(real)*height*width.
*/
void setData(real* data) {
BaseMatrix::setData(data);
memoryHandle_.reset();
}
/// the data should be contiguous
void setData(real* data, size_t newHeight, size_t newWidth) {
setData(data);
height_ = newHeight;
width_ = newWidth;
elementCnt_ = newHeight * newWidth;
stride_ = width_;
}
size_t getWidth() const { return width_; }
size_t getHeight() const { return height_; }
size_t getStride() const { return stride_; }
size_t getElementCnt() const { return elementCnt_; }
virtual real* getData() { return data_; }
virtual const real* getData() const { return data_; }
bool isTransposed() const { return trans_; }
bool isContiguous() const { return stride_ == width_ || height_ == 1; }
// If sparse matrix, need to dynamic_cast to CpuSparseMatrix/GpuSparseMatrix
// befor call the following functions.
// Declare these functions in the base class just easy to call them.
// And these declarations should be moved to base class of sparse matrix
// if refactor sparse matrix
virtual int* getRows() const {
LOG(FATAL) << "Not implemented";
return nullptr; //! suppress warning for no return value.
}
virtual int* getCols() const {
LOG(FATAL) << "Not implemented";
return nullptr; //! suppress warning for no return value.
}
virtual SparseFormat getFormat() const {
LOG(FATAL) << "Not implemented";
return SPARSE_CSR; //! suppress warning for no return value.
}
virtual SparseValueType getValueType() const {
LOG(FATAL) << "Not implemented";
return NO_VALUE; //! suppress warning for no return value.
}
/**
* @brief matrix elment-wise add
*
* Named add3 just because add/add2 has been used in BaseMatrix.cu
* and they are not virtual function.
*/
virtual void add3(MatrixPtr b) { LOG(FATAL) << "Not implemented"; }
MemoryHandlePtr getMemoryHandle() const { return memoryHandle_; }
virtual void zeroMem() { LOG(FATAL) << "Not implemented"; }
virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
void setDiag(real value);
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
virtual void trimFrom(const CpuSparseMatrix& src) {
LOG(FATAL) << "Not implemented";
}
// asynchronous copy
virtual void copyFrom(const Matrix& src, hl_stream_t stream) {
LOG(FATAL) << "Not implemented";
}
MatrixPtr subMatrix(size_t startRow,
size_t endRow,
size_t startCol,
size_t endCol);
MatrixPtr subRowMatrix(size_t startRow, size_t endRow) {
return subMatrix(startRow, endRow, 0, getWidth());
}
MatrixPtr subColMatrix(size_t startCol, size_t endCol) {
return subMatrix(0, getHeight(), startCol, endCol);
}
virtual MatrixPtr subMatrix(size_t startRow, size_t numRows) {
CHECK_LE(startRow + numRows, getHeight());
return Matrix::create(getData() + startRow * getWidth(),
numRows,
getWidth(),
trans_,
useGpu_);
}
virtual MatrixPtr subMatrix(size_t startRow, size_t numRows, MatrixPtr dest) {
CHECK_LE(startRow + numRows, getHeight());
CHECK_EQ(useGpu_, dest->useGpu_);
dest->setData(this->rowBuf(startRow), numRows, getWidth());
return dest;
}
/**
* If this is GpuMatrix, src is assumed to be CPU memory
*
* If this is CpuMatrix, src is assumed to be CPU memory
*/
virtual void copyFrom(const real* src, size_t size) {
LOG(FATAL) << "Not implemented";
}
virtual void copyFrom(const real* src, const int64_t* seq) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief convert a int vector to a real matrix.
*
* (1) source and dest are both in CPU.
*
* (2) sizes are exactly match.
*/
virtual void copyFrom(const IVector& src) {
LOG(FATAL) << "copy data from int vector only available on CpuMatrix.";
}
virtual void copyByRowIndex(Matrix& b, const IVector& rowIndex) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief Create a matrix with the same type (GpuMatrix, CpuMatrix,
* NonValueSparseMatrix, etc.) as this.
*
* If height and width is zero, the new matrix will have the same size
* as this, otherwise the new matrix will have the specified size.
*
*/
virtual MatrixPtr clone(size_t height = 0,
size_t width = 0,
bool useGpu = false) {
LOG(FATAL) << "Not implemented";
return nullptr;
}
virtual real* getRowBuf(size_t row) {
LOG(FATAL) << "Not implemented";
return nullptr;
}
virtual real getElement(size_t x, size_t y) const {
LOG(FATAL) << "Not implemented";
return 0;
}
virtual real getSum() {
LOG(FATAL) << "Not implemented";
return 0;
}
virtual void accumulateColSum(Matrix& src) {
LOG(FATAL) << "Not implemented";
}
virtual real getAbsSum() {
LOG(FATAL) << "Not implemented";
return 0;
}
/**
* @note Original data may not be preserved after resize().
*/
virtual void resize(size_t newHeight, size_t newWidth) = 0;
/**
* @note This should only be used for sparse matrix.
*/
virtual void resize(size_t newHeight,
size_t newWidth,
size_t newNnz, /* total item used to allocate space */
SparseValueType valueType,
SparseFormat format) = 0;
/**
* @brief This should only be used for sparse matrix.
*
* Currently must be called for each row in order.
* The matrix is not valid until setRow is called for the last row.
*/
virtual void setRow(size_t row,
size_t colNum,
const unsigned int* cols,
const real* values) = 0;
virtual MatrixPtr getTranspose() = 0;
/**
* @brief hard transpose.
*
* allocate matTrans' memory outside, then set memAlloc as false;
* else set as true.
*/
virtual void transpose(MatrixPtr matTrans, bool memAlloc) {
LOG(FATAL) << "Not implemented";
}
virtual MatrixPtr getInverse() {
LOG(FATAL) << "Not implemented";
return nullptr;
}
/**
* @brief inverse.
*
* if allocate matInv's memory outside, then set memAlloc as false;
* else set as true.
*/
virtual void inverse(MatrixPtr matInv, bool memAlloc) {
LOG(FATAL) << "Not implemented";
}
public:
/// Only set all variables to 0 or NULL but not free them.
virtual void clear() {
height_ = 0;
width_ = 0;
data_ = NULL;
}
void reshape(size_t height, size_t width);
/// add b to each sample of this.
virtual void addBias(Matrix& b, real scale) {
LOG(FATAL) << "Not implemented";
}
virtual void addSharedBias(Matrix& b, real scale) {
LOG(FATAL) << "Not implemented";
}
void addBias(Matrix& b, real scale, bool sharedBias) {
if (!sharedBias) {
addBias(b, scale);
} else {
addSharedBias(b, scale);
}
}
/// add each sample from a to this.
virtual void collectBias(Matrix& a, real scale) {
LOG(FATAL) << "Not implemented";
}
virtual void collectSharedBias(Matrix& a, real scale) {
LOG(FATAL) << "Not implemented";
}
void collectBias(Matrix& a, real scale, bool sharedBias) {
if (!sharedBias) {
collectBias(a, scale);
} else {
collectSharedBias(a, scale);
}
}
virtual void sequenceAvgForward(Matrix& a,
const IVector& startsPos,
int mode) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* this = scaleAB*(a*b) + scaleT*this
* @endcode
*/
virtual void mul(const Matrix& a,
const Matrix& b,
real scaleAB,
real scaleT) {
LOG(FATAL) << "Not implemented";
}
/// Add a vector (column) b to matrix a, column by column.
virtual void addColumnVector(const Matrix& b) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* For j < codeLength:
* this(i, j) += vec(index(i, j), 0)
* where index(i, j) = ((codes(i) + numClasses) >> (j + 1)) - 1
* @endcode
*/
virtual void addByBitCode(size_t numClasses,
const IVector& codes,
const Matrix& vec) {
(void)numClasses;
(void)codes;
(void)vec;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength:
* vec(index(i, j), 0) += this(i, j)
* where index is same as the index for addByBitCode
* @endcode
*/
virtual void addByBitCodeBackward(size_t numClasses,
const IVector& codes,
Matrix& vec) {
(void)numClasses;
(void)codes;
(void)vec;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength:
* this(i, j) += <mat.row(index(i, j)), input.row(i)>
* where index is same as the index for addByBitCode
* @endcode
*/
virtual void mulByBitCode(size_t numClasses,
const IVector& codes,
const Matrix& mat,
const Matrix& input) {
(void)numClasses;
(void)codes;
(void)mat;
(void)input;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength:
* mat.row(index(i, j)) += this(i, j) * input.row(i)
* where index is same as the index for addByBitCode
* @endcode
*/
virtual void mulByBitCodeBackwardWeight(size_t numClasses,
const IVector& codes,
Matrix& mat,
const Matrix& input) {
(void)numClasses;
(void)codes;
(void)mat;
(void)input;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength:
* input.row(i) += this(i, j) * mat.row(index(i, j))
* where index is same as the index for addByBitCode
* @endcode
*/
virtual void mulByBitCodeBackwardError(size_t numClasses,
const IVector& codes,
const Matrix& mat,
Matrix& input) {
(void)numClasses;
(void)codes;
(void)mat;
(void)input;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength
* sum(i, 0) = scaleSum * \sum_j bit(i, j) * this(i, j)
* where bit(i, j) = ((codes(i) + numClasses) & 2^j) ? 1 : 0
* @endcode
*/
virtual void sumByBitCode(size_t numClasses,
IVector& codes,
Matrix& sum,
real scaleSum) {
(void)numClasses;
(void)codes;
(void)sum;
(void)scaleSum;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* For j < codeLength
* this(i, j) -= bit(i, j)
* where bit(i, j) is same as that for sumByBitCode
* @endcode
*/
virtual void subByBitCode(size_t numClasses_, IVector& codes) {
(void)numClasses_;
(void)codes;
LOG(FATAL) << "Not implemeted";
}
/**
* add the sum of each row of this to mat
*/
virtual void rowSum(Matrix& sum) {
(void)sum;
LOG(FATAL) << "Not implemeted";
}
/**
* set the max of each row of this to mat
*/
virtual void rowMax(Matrix& max) {
(void)max;
LOG(FATAL) << "Not implemeted";
}
/**
* set the max of each column of this to mat
*/
virtual void colMax(Matrix& max) { LOG(FATAL) << "not implemented"; }
/**
* @brief Get the top k elements of each column of this matrix.
*
* The row ids and values of these elements are stored in
* maxIds and max respectively. where k is the size of maxIds.
* And note that the top k elements are not sorted.
*/
virtual void colMax(IVector& maxIds, Matrix& maxVal) {
LOG(FATAL) << "not implemented";
}
virtual void maxoutForward(Matrix& a,
IVector& id,
size_t channels,
size_t groups) {
LOG(FATAL) << "not implemented";
}
virtual void maxoutBackward(Matrix& a,
IVector& id,
size_t channels,
size_t groups) {
LOG(FATAL) << "not implemented";
}
virtual void rowMaxId(IVector& maxIds) { LOG(FATAL) << "Not implemented"; }
/**
* @brief Get the top k elements of each row of this matrix.
*
* The column ids and values of these elements are stored in
* maxIds and max respectively. where k is the size of maxIds.
* And note that the top k elements are not sorted.
*/
virtual void rowMax(IVector& maxIds, Matrix& max) {
LOG(FATAL) << "Not implemented";
}
/// normalize each row so that the sum of each row is 1.
virtual void rowNormalizeL1(Matrix& out) {
(void)out;
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* this = a*b
* @endcode
*/
virtual void mul(const Matrix& a, const Matrix& b) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* this = scaleAB*(this*b) + scaleT*this
* @endcode
*/
virtual void rightMul(Matrix& b, real scaleAB, real scaleT) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* this = this* b
* @endcode
*/
virtual void rightMul(Matrix& b) { LOG(FATAL) << "Not implemented"; }
/**
* @code
* this = scaleAB*(a*this) + scaleT*this
* @endcode
*/
virtual void leftMul(Matrix& a, real scaleAB, real scaleT) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* this = a*this)
* @endcode
*/
virtual void leftMul(Matrix& a) { LOG(FATAL) << "Not implemented"; }
/// merge the element for each col.
virtual void colMerge(Matrix& src) { LOG(FATAL) << "Not implemented"; }
/// copy -log(output[label]) to this->data[i].
virtual void oneHotCrossEntropy(Matrix& output, IVector& label) {
LOG(FATAL) << "Not implemented";
}
/// calculate the error of outputV according to label.
virtual void oneHotCrossEntropyBp(Matrix& outputV, IVector& label) {
LOG(FATAL) << "Not implemented";
}
/// copy -log(output[label]) to this->data[i].
virtual void oneHotCrossEntropyWithSelfNorm(Matrix& output,
IVector& label,
real alpha) {
LOG(FATAL) << "Not implemented";
}
/// calculate the error of outputV according to label.
virtual void oneHotCrossEntropyWithSelfNormBp(Matrix& outputV,
IVector& label,
real alpha) {
LOG(FATAL) << "Not implemented";
}
/**
* \f[
* a[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} b_{i+j} * c_{j}
* \f]
*
* b contains M elements,
* c contains N elements (N is odd),
* b's index arithmetic is computed modulo M,
* c's index arithmetic is computed modulo N.
*/
virtual void circularConv(Matrix& b, Matrix& c) {
LOG(FATAL) << "Not implemented";
}
virtual void circularConvDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2) {
LOG(FATAL) << "Not implemented";
}
/* output_ij = exp(this_{ij}) / (sum_j exp(this_ij)) */
virtual void softmax(Matrix& output) {
(void)output;
LOG(FATAL) << "Not implemeted";
}
virtual void sequenceSoftmax(Matrix& output, const IVector& index) {
(void)output;
LOG(FATAL) << "Not implemeted";
}
virtual void softmaxBackward(Matrix& outputV) {
(void)outputV;
LOG(FATAL) << "Not implemeted";
}
/*
sum_i = sum_j this_ij * output_ij
this_ij = output_ij* (this_ij - sum_i)
*/
virtual void softmaxDerivative(Matrix& output, Matrix& sftmaxSum) {
LOG(FATAL) << "Not implemented";
}
/// calculate the sum of squares diff cost.
virtual void sumOfSquares(Matrix& output, Matrix& label) {
LOG(FATAL) << "Not implemented";
}
/// gradient of sumOfSquares.
virtual void sumOfSquaresBp(Matrix& outputV, Matrix& label) {
LOG(FATAL) << "Not implemented";
}
virtual void tanh(Matrix& output) { LOG(FATAL) << "Not implemented"; }
virtual void tanhDerivative(Matrix& output) {
LOG(FATAL) << "Not implemented";
}
virtual void softrelu(Matrix& output) { LOG(FATAL) << "Not implemented"; }
virtual void softreluDerivative(Matrix& output) {
LOG(FATAL) << "Not implemented";
}
virtual void scaledTanh(Matrix& output, real p1, real p2) {
LOG(FATAL) << "Not implemented";
}
/**
* cosine similarity, for each row i,
* this[i] = cos(output1[i], output2[i])
*
* output2 can only have one row, then for each row i,
* this[i] = cos(output1[i], output2[0])
*/
virtual void cosSim(Matrix& output1, Matrix& output2, real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
virtual void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale = 1.0f) {
LOG(FATAL) << "Not implemented";
}
/// print out the values of elements to os
virtual void print(std::ostream& os) const {
LOG(FATAL) << "Not implemented";
}
/**
* print a part of the matrix
* from the (top,left) value to the (height, width) value (not included)
*/
virtual void print(std::ostream& os, size_t height, size_t width) const {
LOG(FATAL) << "Not implemented";
}
/// print one row to os
virtual void printOneRow(std::ostream& os, size_t idx) const {
LOG(FATAL) << "Not implemented";
}
virtual void check(std::ostream& os, Matrix& refMat, bool printDiff = true) {}
virtual real getMin() {
LOG(FATAL) << "Not implemented";
return 0;
}
virtual real getMax() {
LOG(FATAL) << "Not implemented";
return 0;
}
virtual void randomizeUniform() { LOG(FATAL) << "Not implemented"; }
/**
* @brief calulate the error of classification
*
* output[i] = 1 if row i is an error.
*
* output[i] = 0 if row i is correct.
*/
virtual void classificationError(Matrix& output, IVector& label) {
LOG(FATAL) << "Not implemented";
}
/**
* This function is used to calculate the convolution:
*
* It will expand a feature matrix according to the
* convolution filters
*/
virtual void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW) {
LOG(FATAL) << "Not implemeted";
}
/**
* This function is the reverse implementation of convExpand:
*
* Its function is to restore a expanded-matrix into a feature matrix
*/
virtual void convShrink(Matrix& expandColMat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f) {
LOG(FATAL) << "Not implemeted";
}
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value
*/
virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemeted";
}
/// Pooling backward operation.
virtual void maxPoolBackward(Matrix& image,
size_t imgSizeH,
size_t imgSizeW,
Matrix& outGrad,
Matrix& outV,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemeted";
}
/// Pooling forward operation, caculate the average of sizeX elements.
virtual void avgPoolForward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemeted";
}
virtual void avgPoolBackward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemeted";
}
/**
* Input: one or more sequences. Each sequence contains some instances.
*
* Output: output size is the number of input sequences (NOT input
* instances).
*
* output[i] is set to max_input[i].
*/
virtual void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index) {
LOG(FATAL) << "Not implemeted";
}
virtual void maxSequenceBackward(Matrix& outputGrad,
const IVector& sequence,
IVector& index) {
LOG(FATAL) << "Not implemeted";
}
/**
* @code
* this.row[i] += table.row[ids[i]]
* if ids[i] == -1, it will be ignored
* @endcode
*/
virtual void selectRows(Matrix& table, IVector& ids) {
(void)table;
(void)ids;
LOG(FATAL) << "Not implemented";
}
/**
* @code
* this[i] = table[i, id[i]]
* @endcode
*/
virtual void selectElements(Matrix& table, IVector& ids) {
LOG(FATAL) << "Not implemented";
}
/**
* @code
* table.row[ids[i]] += this.row[i]
* if ids[i] == -1, it will be ignored
* @endcode
*/
virtual void addToRows(Matrix& table, IVector& ids) {
(void)table;
(void)ids;
LOG(FATAL) << "Not implemented";
}
/**
* @code
* table[i, id[i]] += this[i]
* @endcode
*/
virtual void addElements(Matrix& table, IVector& ids) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief cross entropy for multi binary labels
*
* @code
* this[i] = -sum(label[i][j]*log(output[i][j])
* + (1-label[i][j])*log(1-output[i][j]))
* @endcode
*/
virtual void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief The gradient of cross entropy for multi binary labels on output
*
* @code
* this[i][j] = -label[i][j]/output[i][j]
* + (1-label[i][j])/(1-output[i][j])
* @endcode
*/
virtual void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief Calculate the classification error for multi binary labels
*
* @code
* this[i] = sum((output[i][j] >= threshold && label[i][j] == 0)
* || (output[i][j] < threshold && label[i][j] == 1))
* / output->getWidth()
* @endcode
*/
virtual void classificationErrorMulti(Matrix& output,
Matrix& label,
real threshold) {
LOG(FATAL) << "Not implemented";
}
virtual void paramReluForward(Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented";
}
virtual void paramReluBackwardW(Matrix& oGrad, Matrix& data) {
LOG(FATAL) << "Not implemented";
}
virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW) {
LOG(FATAL) << "Not implemented";
}
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
if (useGpu_) {
TensorGpuApply<real>(*this, expr);
} else {
TensorCpuApply<real>(*this, expr);
}
}
bool isEmpty() const {
return data_ == nullptr;
}
explicit operator bool() const {
return !isEmpty();
}
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
mat.print(os);
return os;
}
class GpuMatrix : public Matrix {
public:
GpuMatrix();
GpuMatrix(size_t height, size_t width, bool trans = false);
GpuMatrix(real* data, size_t height, size_t width, bool trans = false)
: Matrix(data, height, width, trans, true) {}
GpuMatrix(real* data,
size_t height,
size_t width,
size_t stride,
bool trans = false)
: Matrix(data, height, width, stride, trans, true) {}
GpuMatrix(GpuMemHandlePtr dataHandle,
size_t height,
size_t width,
bool trans = false)
: Matrix(dataHandle, height, width, trans, true) {}
~GpuMatrix();
void zeroMem();
void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight,
size_t newWidth,
size_t newNnz, /* used to allocate space */
SparseValueType valueType,
SparseFormat format) {
LOG(FATAL) << "Only Support Sparse Matrix";
}
void setRow(size_t row,
size_t colNum,
const unsigned int* cols,
const real* values) {
LOG(FATAL) << "Only Support Sparse Matrix";
}
/**
* Copy the data from cpu_memory buffer
*/
void copyFrom(const real* hostSrc, size_t size);
void copyFrom(const real* hostSrc, const int64_t* seq);
void copyFrom(const Matrix& src, hl_stream_t stream);
void copyFrom(const Matrix& src);
void copyFrom(const IVector& src);
void copyByRowIndex(Matrix& b, const IVector& rowIndex);
MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
real getElement(size_t x, size_t y) const;
real* getRow(size_t row) { return BaseMatrix::rowBuf(row); }
virtual real* getRowBuf(size_t row) { return getRow(row); }
real getSum();
void accumulateColSum(Matrix& src);
real getAbsSum();
MatrixPtr getTranspose();
void transpose(MatrixPtr matTrans, bool memAlloc);
MatrixPtr getInverse();
void inverse(MatrixPtr matInv, bool memAlloc);
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
void addSharedBias(Matrix& b, real scale);
/**
* @code
* add each sample from a to this.
* @endcode
*/
void collectBias(Matrix& a, real scale);
void collectSharedBias(Matrix& a, real scale);
void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
/**
* @code
* this.row[i] += table.row[ids[i]]
* @endcode
*/
virtual void selectRows(Matrix& table, IVector& ids);
/**
* @code
* this[i] = table[i, id[i]]
* @endcode
*/
virtual void selectElements(Matrix& table, IVector& ids);
/**
* @code
* table.row[ids[i]] += this.row[i]
* @endcode
*/
virtual void addToRows(Matrix& table, IVector& ids);
void addColumnVector(const Matrix& b);
/**
* @code
* this = scaleAB*(a*b) + scaleT*this
* @endcode
*/
void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT);
/**
* @code
* this = a*b
* @endcode
*/
void mul(const Matrix& a, const Matrix& b);
void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT);
void mul(const GpuSparseMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT);
void mul(const GpuMatrix& a,
const GpuSparseMatrix& b,
real scaleAB,
real scaleT);
/**
* @code
* this = scaleAB*(this*b) + scaleT*this
* @endcode
*/
void rightMul(Matrix& b, real scaleAB, real scaleT);
/**
* @code
* this = this* b
* @endcode
*/
void rightMul(Matrix& b);
/**
* @code
* this = scaleAB*(a*this) + scaleT*this
* @endcode
*/
void leftMul(Matrix& a, real scaleAB, real scaleT);
/**
* @code
* this = a*this
* @endcode
*/
void leftMul(Matrix& a);
void colMerge(Matrix& src);
void rowSum(Matrix& sum);
void rowMax(Matrix& max);
void rowMax(IVector& maxIds, Matrix& max);
void colMax(Matrix& max);
void colMax(IVector& maxIds, Matrix& max);
void maxoutForward(Matrix& a, IVector& id, size_t channels, size_t groups);
void maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t groups);
void oneHotCrossEntropy(Matrix& output, IVector& label);
void oneHotCrossEntropyBp(Matrix& outputV, IVector& label);
void oneHotCrossEntropyWithSelfNorm(Matrix& output,
IVector& label,
real alpha);
void oneHotCrossEntropyWithSelfNormBp(Matrix& outputV,
IVector& label,
real alpha);
void softmax(Matrix& output);
void sequenceSoftmax(Matrix& output, const IVector& index);
void softmaxBackward(Matrix& outputV);
void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
/// calculate the sum of squares diff cost.
void sumOfSquares(Matrix& output, Matrix& label);
/// gradient of sumOfSquares.
void sumOfSquaresBp(Matrix& outputV, Matrix& label);
void tanh(Matrix& output);
void tanhDerivative(Matrix& output);
void softrelu(Matrix& output);
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
virtual void print(std::ostream& os) const;
virtual void print(std::ostream& os, size_t height, size_t width) const;
void paramReluForward(Matrix& data, Matrix& W);
void paramReluBackwardW(Matrix& oGrad, Matrix& data);
void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W);
void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
void randomizeUniform();
void classificationError(Matrix& output, IVector& label);
void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW);
void convShrink(Matrix& expandColMat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blochW,
int strideH,
int strideW,
int paddingH,
int paddingWreal,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
size_t imgSizeW,
Matrix& outGrad,
Matrix& outV,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
void avgPoolForward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
void avgPoolBackward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
void maxSequenceBackward(Matrix& outputGrad,
const IVector& sequence,
IVector& index);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorGpuApply<real>(*this, expr);
}
};
class CpuMatrix : public Matrix {
public:
CpuMatrix(size_t height, size_t width, bool trans = false);
CpuMatrix(real* data, size_t height, size_t width, bool trans = false)
: Matrix(data, height, width, trans, false) {}
CpuMatrix(real* data,
size_t height,
size_t width,
size_t stride,
bool trans = false)
: Matrix(data, height, width, stride, trans, false) {}
CpuMatrix(CpuMemHandlePtr dataHandle,
size_t height,
size_t width,
bool trans = false)
: Matrix(dataHandle, height, width, trans, false) {}
~CpuMatrix();
void zeroMem();
void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight,
size_t newWidth,
size_t newNnz, /* used to allocate space */
SparseValueType valueType,
SparseFormat format) {
LOG(FATAL) << "Only Support Sparse Matrix";
}
void setRow(size_t row,
size_t colNum,
const unsigned int* cols,
const real* values) {
LOG(FATAL) << "Only Support Sparse Matrix";
}
real getElement(size_t x, size_t y) const;
real getSum();
void accumulateColSum(Matrix& src);
real getAbsSum();
MatrixPtr getTranspose();
void transpose(MatrixPtr matTrans, bool memAlloc);
MatrixPtr getInverse();
void inverse(MatrixPtr matInv, bool memAlloc);
void copyFrom(const Matrix& src);
void copyFrom(const Matrix& src, hl_stream_t stream);
void copyFrom(const real* cpuSrc, size_t size);
void copyFrom(const real* cpuSrc, const int64_t* seq);
void copyFrom(const IVector& src);
void copyFrom(CpuSparseMatrix& src);
void copyByRowIndex(Matrix& b, const IVector& rowIndex);
MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blcokH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW);
void convShrink(Matrix& expandFeat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
void maxPoolBackward(Matrix& image,
size_t imgSizeH,
size_t imgSizeW,
Matrix& outGrad,
Matrix& outV,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
void avgPoolForward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
void avgPoolBackward(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
size_t sizeY,
size_t strideH,
size_t strideW,
size_t outputH,
size_t outputW,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
void maxSequenceBackward(Matrix& outputGrad,
const IVector& sequence,
IVector& index);
real* getRow(size_t row) { return BaseMatrix::rowBuf(row); }
virtual real* getRowBuf(size_t row) { return getRow(row); }
public:
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
void addSharedBias(Matrix& b, real scale);
/// add each sample of a to this.
void collectBias(Matrix& a, real scale);
void collectSharedBias(Matrix& a, real scale);
void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
/**
* @code
* this.row[i] += table.row[ids[i]]
* @endcode
*/
virtual void selectRows(Matrix& table, IVector& ids);
/**
* @code
* table.row[ids[i]] += this.row[i]
* @endcode
*/
virtual void addToRows(Matrix& table, IVector& ids);
/**
* @code
* this[i] = table[i, id[i]]
* @endcode
*/
virtual void selectElements(Matrix& table, IVector& ids);
/**
* @code
* table[i, id[i]] += this[i]
* @endcode
*/
virtual void addElements(Matrix& table, IVector& ids);
/**
* use abstract getRow() to get row from table.
*
* Define table as template instead of virtual class for performance sake.
* internal used by above two virtual funcs.
*/
template <typename TableMatType>
void selectRowsImp(TableMatType& table, IVector& ids);
template <typename TableMatType>
void addToRowsImp(TableMatType& table, IVector& ids);
void addColumnVector(const Matrix& b);
void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT);
void mul(CpuMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void mul(CpuMatrix* a, CpuSparseMatrix* b, real scaleAB, real scaleT);
static void mul(CpuMatrix* a,
CpuMatrix* b,
CpuSparseMatrix* c,
real scaleAB,
real scaleT);
/**
* c = a * b
*
* use abstract getRow() to get row from B,C.
* Define B,C as template instead of virtual class for performance sake.
*/
template <typename MatBType, typename MatCType>
static void mul(
CpuSparseMatrix* a, MatBType* b, MatCType* c, real scaleAB, real scaleT);
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void mul(const Matrix& a, const Matrix& b);
void rightMul(Matrix& b, real scaleAB, real scaleT);
void rightMul(Matrix& b);
void leftMul(Matrix& a, real scaleAB, real scaleT);
void leftMul(Matrix& a);
void colMerge(Matrix& src);
void rowSum(Matrix& sum);
void rowMaxId(IVector& maxIds);
void rowMax(Matrix& max);
void rowMax(IVector& maxIds, Matrix& maxVal);
void colMax(Matrix& max);
void colMax(IVector& maxIds, Matrix& maxVal);
void maxoutForward(Matrix& a, IVector& id, size_t channels, size_t groups);
void maxoutBackward(Matrix& a, IVector& id, size_t channels, size_t groups);
void rowNormalizeL1(Matrix& out);
void oneHotCrossEntropy(Matrix& output, IVector& label);
void oneHotCrossEntropyBp(Matrix& outputV, IVector& label);
void oneHotCrossEntropyWithSelfNorm(Matrix& output,
IVector& label,
real alpha);
void oneHotCrossEntropyWithSelfNormBp(Matrix& outputV,
IVector& label,
real alpha);
void circularConv(Matrix& b, Matrix& c);
void circularConvDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2);
void softmax(Matrix& output);
void sequenceSoftmax(Matrix& output, const IVector& index);
void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
/// calculate the sum of squares diff cost.
void sumOfSquares(Matrix& output, Matrix& label);
/// gradient of sumOfSquares.
void sumOfSquaresBp(Matrix& outputV, Matrix& label);
void tanh(Matrix& output);
void tanhDerivative(Matrix& output);
void softrelu(Matrix& output);
void softreluDerivative(Matrix& output);
void scaledTanh(Matrix& output, real p1, real p2);
void cosSim(Matrix& output1, Matrix& output2, real scale);
void cosSimDerivative(Matrix& output,
Matrix& prevOut1,
Matrix& prevOut2,
Matrix& prevGrad1,
Matrix& prevGrad2,
real scale);
void print(std::ostream& os) const;
void print(std::ostream& os, size_t height, size_t width) const;
void printOneRow(std::ostream& os, size_t idx) const;
void paramReluForward(Matrix& data, Matrix& W);
void paramReluBackwardW(Matrix& oGrad, Matrix& data);
void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W);
void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
real getMin();
real getMax();
void randomizeUniform();
void classificationError(Matrix& output, IVector& label);
void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
void addByBitCodeBackward(size_t numClasses,
const IVector& codes,
Matrix& vec);
void mulByBitCode(size_t numClasses,
const IVector& codes,
const Matrix& mat,
const Matrix& input);
void mulByBitCodeBackwardWeight(size_t numClasses,
const IVector& codes,
Matrix& mat,
const Matrix& input);
void mulByBitCodeBackwardError(size_t numClasses,
const IVector& codes,
const Matrix& mat,
Matrix& input);
void sumByBitCode(size_t numClasses,
IVector& codes,
Matrix& sum,
real scaleSum);
void subByBitCode(size_t numClasses_, IVector& codes);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);
void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
const size_t outImgH,
const size_t outImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
void bilinearBackward(const Matrix& out,
const size_t outImgH,
const size_t outImgW,
const size_t inImgH,
const size_t inImgW,
const size_t numChannels,
const real ratioH,
const real ratioW);
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorCpuApply<real>(*this, expr);
}
};
class SharedCpuMatrix : public CpuMatrix {
public:
/* blockNum is number of partitions of the matrix */
SharedCpuMatrix(int blockNum, size_t height, size_t width, bool trans = false)
: CpuMatrix(height, width, trans) {
initShared(blockNum);
}
SharedCpuMatrix(
int blockNum, real* data, size_t height, size_t width, bool trans = false)
: CpuMatrix(data, height, width, trans) {
initShared(blockNum);
}
SharedCpuMatrix(int blockNum,
CpuMemHandlePtr dataHandle,
size_t height,
size_t width,
bool trans = false)
: CpuMatrix(dataHandle, height, width, trans) {
initShared(blockNum);
}
SharedCpuMatrix(CpuMemHandlePtr dataHandle,
size_t height,
size_t width,
bool trans = false)
: CpuMatrix(dataHandle, height, width, trans) {
initBlock(1);
}
~SharedCpuMatrix() {}
public:
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
virtual void add(Matrix& b, real p1, real p2);
virtual void add(real p1, real p2);
private:
using Matrix::mul;
void initShared(int blockNum);
void initBlock(int blockNum);
int blockNum_;
std::vector<std::unique_ptr<std::mutex>> blockLocks_;
ThreadLocal<CpuMatrixPtr> localBuf_;
ThreadLocal<std::vector<int>> localBufRows_;
ThreadLocal<std::vector<int>> blockSeq_;
};
typedef struct { unsigned int col; } sparse_non_value_t;
typedef struct {
unsigned int col;
float value;
} sparse_float_value_t;
} // namespace paddle
#include "ExecViaCpu.h"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册