Commit 7df67bae authored by H hedaoyuan, committed by GitHub

Merge pull request #1064 from hedaoyuan/buffer

Add BufferArg as the Function argument type and modify the Function prototype to remove the inouts argument.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "BufferArg.h"
namespace paddle {
const SequenceArg& BufferArg::sequence() const {
// CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
return dynamic_cast<const SequenceArg&>(*this);
}
const SparseMatrixArg& BufferArg::sparse() const {
// CHECK_EQ(bufferType_, TENSOR_SPARSE);
return dynamic_cast<const SparseMatrixArg&>(*this);
}
} // namespace paddle
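
The two accessors above downcast a BufferArg reference to its derived type. Because dynamic_cast is applied to a reference rather than a pointer, a mismatched runtime type throws std::bad_cast instead of returning nullptr. A minimal sketch of the intended use, with a hypothetical helper that is not part of this commit, assuming the declarations in BufferArg.h below:

// Hypothetical caller. `arg` must actually be a SparseMatrixArg;
// otherwise arg.sparse() throws std::bad_cast.
void logSparseNnz(const paddle::BufferArg& arg) {
  const paddle::SparseMatrixArg& s = arg.sparse();
  LOG(INFO) << "nnz = " << s.nnz();
}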
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include "TensorShape.h"
#include "TensorType.h"
#include "paddle/math/CpuSparseMatrix.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
namespace paddle {
enum BufferType {
TENSOR_NORMAL = 0,
TENSOR_SEQUENCE_ID = 1,
TENSOR_SEQUENCE_DATA = 2,
TENSOR_SPARSE = 3
};
enum SparseDataType {
SPARSE_NO_VALUE = 0, // no value pointer is needed; all values are 1
SPARSE_FLOAT_VALUE = 1
};
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class BufferArg;
class SequenceArg;
class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr;
/**
 * \brief BufferArg is used as the argument type of Function.
 *
 * The arguments of a Paddle Function use four Buffer types:
 * 1. BufferArg for a dense Buffer of any dimension.
 * 2. SequenceIdArg for a Buffer of sequence start positions.
 * 3. SequenceArg for a Buffer of sequence data.
 * 4. SparseMatrixArg for a Buffer of a sparse matrix.
 *
 * A BufferArg used as a Function output carries an ArgType property.
 * The argType_ of the output BufferArg determines whether the result of
 * the Function calculation is assigned to the output Buffer or added
 * to it.
 */
// ArgType is only used by output BufferArgs.
// For an input argument, argType_ is ignored.
// For an output argument, argType_ must be set.
enum ArgType {
UNSPECIFIED = 0,
ASSIGN_TO = 1,
ADD_TO = 2,
};
class BufferArg {
public:
void setArgType(ArgType argType) { argType_ = argType; }
ArgType getArgType() const { return argType_; }
public:
BufferArg(void* buf,
ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
BufferArg(void* buf, ValueType valueType)
: buf_(buf), valueType_(valueType) {}
BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
: buf_(
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value),
shape_(2),
argType_(argType) {
shape_.setDim(0, matrix.getHeight());
shape_.setDim(1, matrix.getWidth());
}
BufferArg(const Matrix& matrix,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value),
shape_(shape),
argType_(argType) {
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
}
BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED)
: buf_(
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
valueType_(DataType<real>::value),
shape_(1),
argType_(argType) {
shape_.setDim(0, vector.getSize());
}
BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED)
: buf_(
const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
valueType_(VALUE_TYPE_INT32),
shape_(1),
argType_(argType) {
shape_.setDim(0, vector.getSize());
}
template <DeviceType DType>
typename Tensor<real, DType>::Matrix matrix() const {
CHECK(buf_);
CHECK(valueType_ == DataType<real>::value);
// CHECK(deviceType_ == DType);
CHECK_EQ((size_t)2, shape_.ndims());
return typename Tensor<real, DType>::Matrix(
reinterpret_cast<real*>(buf_), shape_[0], shape_[1]);
}
template <typename VType, DeviceType DType>
typename Tensor<VType, DType>::Vector vector() const {
CHECK(buf_);
CHECK(valueType_ == DataType<VType>::value);
// CHECK(deviceType_ == DType);
CHECK_EQ((size_t)1, shape_.ndims());
return typename Tensor<VType, DType>::Vector(
shape_[0], reinterpret_cast<VType*>(buf_));
}
virtual ~BufferArg() {}
template <typename T>
T* data() const {
return reinterpret_cast<T*>(buf_);
}
void* data() const { return buf_; }
ValueType valueType() const { return valueType_; }
BufferType bufferType() const { return bufferType_; }
const TensorShape& shape() const { return shape_; }
const SequenceArg& sequence() const;
const SparseMatrixArg& sparse() const;
protected:
void* buf_;
ValueType valueType_;
TensorShape shape_;
BufferType bufferType_;
ArgType argType_ = UNSPECIFIED;
// leading dimensions. The size is dims_.size()
// Dims lds_;
};
// Sequence start positions in a mini-batch of sequences.
// shape_.ndims() == 1
// valueType_ == int32
// Start positions are strictly increasing:
// if a < b then buf_[a] < buf_[b].
class SequenceIdArg : public BufferArg {
public:
SequenceIdArg(void* buf,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1;
}
SequenceIdArg(const IVector& vector) : BufferArg(vector) {
numSeqs_ = shape_[0] - 1;
}
~SequenceIdArg() {}
size_t numSeqs() const { return numSeqs_; }
private:
size_t numSeqs_;
};
// sequence data
class SequenceArg : public BufferArg {
public:
SequenceArg(void* buf,
ValueType valueType,
const TensorShape& shape,
const SequenceIdArg& startPositions,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, valueType, shape, argType),
startPositions_(startPositions) {}
SequenceArg(const Matrix& matrix,
const IVector& vector,
ArgType argType = UNSPECIFIED)
: BufferArg(matrix, argType), startPositions_(vector) {}
~SequenceArg() {}
void* getIdBuf() const { return startPositions_.data(); }
size_t numSeqs() const { return startPositions_.numSeqs(); }
private:
SequenceIdArg startPositions_;
};
// sparse matrix
// valueType_ == float or double
// shape_.ndims() == 2
class SparseMatrixArg : public BufferArg {
public:
SparseMatrixArg(void* buf,
ValueType valueType,
const TensorShape& shape,
const BufferArg& row,
const BufferArg& col,
size_t nnz,
SparseDataFormat format,
SparseDataType type,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, valueType, shape, argType),
row_(row),
col_(col),
nnz_(nnz),
format_(format),
type_(type) {
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2);
CHECK_EQ(row_.shape().ndims(), (size_t)1);
CHECK_EQ(col_.shape().ndims(), (size_t)1);
if (format == SPARSE_CSR_FORMAT) {
CHECK_EQ(nnz, col.shape()[0]);
} else if (format == SPARSE_CSC_FORMAT) {
CHECK_EQ(nnz, row.shape()[0]);
}
}
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
~SparseMatrixArg() {}
void* getRowBuf() const { return row_.data(); }
void* getColBuf() const { return col_.data(); }
size_t nnz() const { return nnz_; }
SparseDataFormat dataFormat() const { return format_; }
SparseDataType dataType() const { return type_; }
private:
BufferArg row_;
BufferArg col_;
size_t nnz_;
SparseDataFormat format_;
SparseDataType type_;
};
} // namespace paddle
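
To make the header above concrete: a dense buffer can be wrapped without copying and read back as a typed matrix through the Tensor type traits. A minimal sketch, assuming a single-precision build where real is float; the buffer and shape are illustrative only:

// Wrap an existing float buffer as a 2 x 3 BufferArg, then view it as a
// CpuMatrix. No data is copied; the BufferArg only records the pointer,
// the value type, and the shape.
void bufferArgSketch() {
  float data[6] = {0, 1, 2, 3, 4, 5};
  paddle::TensorShape shape({2, 3});
  paddle::BufferArg arg(data, paddle::VALUE_TYPE_FLOAT, shape);
  auto mat = arg.matrix<paddle::DEVICE_TYPE_CPU>();  // Tensor<real, CPU>::Matrix
  CHECK_EQ(mat.getHeight(), (size_t)2);
  CHECK_EQ(mat.getWidth(), (size_t)3);
}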
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "BufferArg.h"
#include <gtest/gtest.h>
#include "Function.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
TEST(BufferTest, BufferArg) {
TensorShape shape({8, 10});
CpuMemoryHandle memory(shape.getElements() *
sizeOfValuType(VALUE_TYPE_FLOAT));
BufferArg buffer(memory.getBuf(), VALUE_TYPE_FLOAT, shape);
EXPECT_EQ(buffer.data(), memory.getBuf());
}
TEST(BufferTest, SequenceIdArg) {
TensorShape shape({10});
CpuMemoryHandle memory(shape.getElements() *
sizeOfValuType(VALUE_TYPE_INT32));
SequenceIdArg buffer(memory.getBuf(), shape);
EXPECT_EQ(buffer.data(), memory.getBuf());
EXPECT_EQ(buffer.numSeqs(), 9);
}
TEST(BufferTest, asArgument) {
MatrixPtr matrix = Matrix::create(100, 200);
VectorPtr vector = Vector::create(100, false);
CpuSparseMatrix sparse(200, 300, 50);
// prepare arguments
BufferArgs arguments;
arguments.addArg(*matrix);
arguments.addArg(*vector);
arguments.addArg(sparse);
// function
auto function = [=](const BufferArgs& inputs) {
EXPECT_EQ(inputs.size(), 3);
// check inputs[0]
EXPECT_EQ(inputs[0].shape().ndims(), 2);
EXPECT_EQ(inputs[0].shape()[0], 100);
EXPECT_EQ(inputs[0].shape()[1], 200);
EXPECT_EQ(inputs[0].data(), matrix->getData());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getHeight(),
matrix->getHeight());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getWidth(),
matrix->getWidth());
EXPECT_EQ(inputs[0].matrix<DEVICE_TYPE_CPU>().getData(), matrix->getData());
// check inputs[1]
EXPECT_EQ(inputs[1].shape().ndims(), 1);
EXPECT_EQ(inputs[1].shape()[0], 100);
EXPECT_EQ(inputs[1].data(), vector->getData());
CpuVector inVector = inputs[1].vector<real, DEVICE_TYPE_CPU>();
EXPECT_EQ(inVector.getSize(), vector->getSize());
EXPECT_EQ(inVector.getData(), vector->getData());
// check inputs[2]
EXPECT_EQ(inputs[2].shape().ndims(), 2);
EXPECT_EQ(inputs[2].shape()[0], 200);
EXPECT_EQ(inputs[2].shape()[1], 300);
EXPECT_EQ(inputs[2].data(), sparse.getData());
// CHECK_EQ(inputs[2].sparse().nnz(), 50);
// CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT);
// CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE);
EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows());
EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols());
};
// call function
function(arguments);
}
} // namespace paddle
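
The SequenceIdArg convention above (monotonically increasing start positions, numSeqs() == shape[0] - 1) can be made concrete with a small sketch; the start positions below are illustrative only:

void sequenceIdSketch() {
  // Three sequences of lengths 3, 2, and 4 packed into a batch of 9 rows;
  // the rows of sequence i are [starts[i], starts[i + 1]).
  int starts[] = {0, 3, 5, 9};
  paddle::SequenceIdArg seqId(starts, paddle::TensorShape({4}));
  CHECK_EQ(seqId.numSeqs(), (size_t)3);
}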
@@ -3,6 +3,7 @@ file(GLOB cpp_files . *Op.cpp)
list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)
list(APPEND cpp_files BufferArg.cpp)
if(WITH_GPU)
file(GLOB cu_files . *OpGpu.cu)
@@ -18,8 +19,12 @@ if(WITH_TESTING)
# TODO:
# file(GLOB test_files . *OpTest.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(CrossMapNormalOpTest)
add_simple_unittest(ContextProjectionOpTest)
# add_simple_unittest(CrossMapNormalOpTest)
add_simple_unittest(TensorShapeTest)
add_simple_unittest(TensorTypeTest)
add_simple_unittest(BufferArgTest)
add_simple_unittest(FunctionTest)
# add_simple_unittest(ContextProjectionOpTest)
endif()
endif()
......
@@ -19,17 +19,15 @@ limitations under the License. */
namespace paddle {
template <>
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
const CpuMatrix* input_mat,
const CpuMatrix* weight_mat,
void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
const CpuMatrix& input_mat,
const CpuMatrix& weight_mat,
const CpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad) {
const int* starts = seq_vec.getData();
const size_t num_sequences = seq_vec.getSize() - 1;
auto w_mat = const_cast<CpuMatrix*>(weight_mat);
auto in_mat = const_cast<CpuMatrix*>(input_mat);
for (size_t i = 0; i < num_sequences; ++i) {
for (size_t j = 0; j < context_length; ++j) {
int begin = starts[i] + context_start + j;
@@ -39,10 +37,11 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix* out_mat,
if (begin < starts[i]) {
int64_t pad_size =
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size);
if (w_mat) {
MatrixPtr sub = w_mat->subMatrix(j, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
MatrixPtr mat = out_mat.subMatrix(starts[i], pad_size);
if (weight_mat) {
MatrixPtr sub =
const_cast<CpuMatrix&>(weight_mat).subMatrix(j, pad_size);
mat->addAtOffset(*sub, j * input_mat.getWidth());
}
dst_begin = starts[i] + pad_size;
begin = starts[i];
@@ -50,19 +49,22 @@
if (end > starts[i + 1]) {
int64_t pad_size =
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
if (w_mat) {
MatrixPtr sub = w_mat->subMatrix(
begin_pad + context_start + j - pad_size, pad_size);
mat->addAtOffset(*sub, j * in_mat->getWidth());
MatrixPtr mat = out_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
if (weight_mat) {
MatrixPtr sub =
const_cast<CpuMatrix&>(weight_mat)
.subMatrix(begin_pad + context_start + j - pad_size,
pad_size);
mat->addAtOffset(*sub, j * input_mat.getWidth());
}
dst_end = starts[i + 1] - pad_size;
end = starts[i + 1];
}
if (end <= begin) continue;
MatrixPtr src = in_mat->subMatrix(begin, end - begin);
MatrixPtr dst = out_mat->subMatrix(dst_begin, dst_end - dst_begin);
dst->addAtOffset(*src, j * in_mat->getWidth());
MatrixPtr src =
const_cast<CpuMatrix&>(input_mat).subMatrix(begin, end - begin);
MatrixPtr dst = out_mat.subMatrix(dst_begin, dst_end - dst_begin);
dst->addAtOffset(*src, j * input_mat.getWidth());
}
}
}
@@ -82,40 +84,32 @@ public:
begin_pad_ = config.get<size_t>("begin_pad");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(3, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)3, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData());
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[2].dims_.size()), 1);
CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[2].shape().ndims(), (size_t)1);
/// dim of output = dim of input * context_length
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
/// dim of input == dim of weight
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
/// input and output have the same batch_size
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
auto out_mat = std::make_shared<typename MatrixT<Device>::type>(
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
const auto in_mat = std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
const auto w_mat =
!inputs[1].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
ContextProjectionForward<Device>(out_mat.get(),
in_mat.get(),
w_mat.get(),
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
auto out_mat = outputs[0].matrix<Device>();
auto in_mat = inputs[0].matrix<Device>();
auto w_mat = !inputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[1].matrix<Device>();
auto seq_vec = inputs[2].vector<int, Device>();
ContextProjectionForward<Device>(out_mat,
in_mat,
w_mat,
seq_vec,
context_length_,
context_start_,
@@ -129,18 +123,17 @@ private:
};
template <>
void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
CpuMatrix* in_grad_mat,
CpuMatrix* w_grad_mat,
void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
CpuMatrix& in_grad_mat,
CpuMatrix& w_grad_mat,
const CpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad) {
CHECK(out_grad_mat);
size_t input_dim = in_grad_mat ? in_grad_mat->getWidth()
: w_grad_mat ? w_grad_mat->getWidth() : 0;
size_t input_dim = in_grad_mat ? in_grad_mat.getWidth()
: w_grad_mat ? w_grad_mat.getWidth() : 0;
const int* starts = seq_vec.getData();
size_t num_sequences = seq_vec.getSize() - 1;
for (size_t i = 0; i < num_sequences; ++i) {
@@ -153,8 +146,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
int64_t pad_size =
std::min(starts[i] - begin, starts[i + 1] - starts[i]);
if (is_padding && w_grad_mat) {
MatrixPtr mat = out_grad_mat->subMatrix(starts[i], pad_size);
MatrixPtr sub = w_grad_mat->subMatrix(j, pad_size);
MatrixPtr mat = out_grad_mat.subMatrix(starts[i], pad_size);
MatrixPtr sub = w_grad_mat.subMatrix(j, pad_size);
sub->addAtOffset(*mat, j * input_dim);
}
dst_begin = starts[i] + pad_size;
@@ -165,8 +158,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
std::min(end - starts[i + 1], starts[i + 1] - starts[i]);
if (is_padding && w_grad_mat) {
MatrixPtr mat =
out_grad_mat->subMatrix(starts[i + 1] - pad_size, pad_size);
MatrixPtr sub = w_grad_mat->subMatrix(
out_grad_mat.subMatrix(starts[i + 1] - pad_size, pad_size);
MatrixPtr sub = w_grad_mat.subMatrix(
begin_pad + context_start + j - pad_size, pad_size);
sub->addAtOffset(*mat, j * input_dim);
}
@@ -175,8 +168,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix* out_grad_mat,
}
if (end <= begin) continue;
if (!in_grad_mat) continue;
MatrixPtr src = in_grad_mat->subMatrix(begin, end - begin);
MatrixPtr dst = out_grad_mat->subMatrix(dst_begin, dst_end - dst_begin);
MatrixPtr src = in_grad_mat.subMatrix(begin, end - begin);
MatrixPtr dst = out_grad_mat.subMatrix(dst_begin, dst_end - dst_begin);
src->addAtOffset(*dst, j * input_dim);
}
}
@@ -199,44 +192,36 @@ public:
total_pad_ = config.get<size_t>("total_pad");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(3, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)3, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(outputs[0].getData() && inputs[2].getData());
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 2);
CHECK_EQ(static_cast<int>(inputs[2].dims_.size()), 1);
CHECK(outputs[0].data() && inputs[2].data());
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[2].shape().ndims(), (size_t)1);
/// dim of input == dim of weight
CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]);
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
/// input and output have the same batch_size
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
/// dim of output = dim of input * context_length
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
auto in_grad_mat =
!inputs[0].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
auto w_grad_mat =
!inputs[1].getData()
? nullptr
: std::make_shared<typename MatrixT<Device>::type>(
inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[2].dims_[0], reinterpret_cast<int*>(inputs[2].getData()));
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
ContextProjectionBackward<Device>(out_grad_mat.get(),
in_grad_mat ? in_grad_mat.get() : nullptr,
w_grad_mat ? w_grad_mat.get() : nullptr,
auto out_grad_mat = outputs[0].matrix<Device>();
auto in_grad_mat =
!inputs[0].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[0].matrix<Device>();
auto w_grad_mat = !inputs[1].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[1].matrix<Device>();
auto seq_vec = inputs[2].vector<int, Device>();
ContextProjectionBackward<Device>(out_grad_mat,
in_grad_mat,
w_grad_mat,
seq_vec,
context_length_,
context_start_,
@@ -253,6 +238,7 @@ private:
size_t total_pad_;
};
#if 0
/**
* \param inputs[0] input grad.
* \param inputs[1] input sequence.
@@ -349,6 +335,7 @@ private:
size_t begin_pad_;
size_t total_pad_;
};
#endif
REGISTER_TYPED_FUNC(ContextProjectionForward,
CPU,
@@ -363,6 +350,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC(ContextProjectionBackward,
GPU,
ContextProjectionBackwardFunc);
#if 0
REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
GPU,
ContextProjectionBackwardDataFunc);
@@ -370,4 +358,5 @@ REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
GPU,
ContextProjectionBackwardWeightFunc);
#endif
#endif
} // namespace paddle
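
For reference, context projection concatenates a sliding window of context_length input rows around each position, so the output width equals the input width times context_length, as the CHECKs in calc above enforce. A single-sequence sketch of the forward semantics, ignoring the trainable padding rows handled by weight_mat; the helper is hypothetical and not part of this commit:

// `out` holds rows * (context_length * dim) elements, pre-zeroed by the
// caller. With context_length = 3 and context_start = -1, output row i is
// the concatenation of input rows i-1, i, i+1 (missing rows stay zero).
void contextProjectRef(const float* in, float* out, int rows, int dim,
                       int context_length, int context_start) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < context_length; ++j) {
      int src = i + context_start + j;
      if (src < 0 || src >= rows) continue;  // outside the sequence
      for (int d = 0; d < dim; ++d) {
        out[(i * context_length + j) * dim + d] = in[src * dim + d];
      }
    }
  }
}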
@@ -31,14 +31,15 @@ namespace paddle {
* \param[in] is_padding whether to pad with zeros or not.
*
*/
template <DeviceType Device>
void ContextProjectionForward(typename MatrixT<Device>::type* output,
const typename MatrixT<Device>::type* input,
const typename MatrixT<Device>::type* weight,
const typename SequenceT<Device>::type& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
template <DeviceType DType>
void ContextProjectionForward(
typename Tensor<real, DType>::Matrix& output,
const typename Tensor<real, DType>::Matrix& input,
const typename Tensor<real, DType>::Matrix& weight,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start,
size_t begin_pad);
/**
* \brief Context Projection Backward.
@@ -53,30 +54,31 @@ void ContextProjectionForward(typename MatrixT<Device>::type* output,
* \param[in] is_padding whether to pad with zeros or not.
*
*/
template <DeviceType Device>
void ContextProjectionBackward(typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType DType>
void ContextProjectionBackward(
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardData(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* in_grad,
const typename SequenceT<Device>::type& sequence,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& in_grad,
const typename Tensor<int, DType>::Vector& sequence,
size_t context_length,
int context_start);
template <DeviceType Device>
template <DeviceType DType>
void ContextProjectionBackwardWeight(
typename MatrixT<Device>::type* out_grad,
typename MatrixT<Device>::type* w_grad,
const typename SequenceT<Device>::type& seq_vec,
typename Tensor<real, DType>::Matrix& out_grad,
typename Tensor<real, DType>::Matrix& w_grad,
const typename Tensor<int, DType>::Vector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,
......
@@ -120,20 +120,19 @@ void hl_context_projection_forward(const real* input,
}
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix* output,
const GpuMatrix* input,
const GpuMatrix* weight,
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
const GpuMatrix& input,
const GpuMatrix& weight,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad) {
CHECK(input && output);
hl_context_projection_forward(input->getData(),
hl_context_projection_forward(input.getData(),
sequence.getData(),
weight ? weight->getData() : nullptr,
output->getData(),
weight ? weight.getData() : nullptr,
output.getData(),
sequence.getSize() - 1,
input->getWidth(),
input.getWidth(),
context_length,
context_start,
begin_pad);
@@ -217,17 +216,16 @@ void hl_context_projection_backward_data(real* out_grad,
}
template <>
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start) {
CHECK(in_grad && out_grad);
hl_context_projection_backward_data(out_grad->getData(),
hl_context_projection_backward_data(out_grad.getData(),
sequence.getData(),
in_grad->getData(),
in_grad.getData(),
sequence.getSize() - 1,
in_grad->getWidth(),
in_grad.getWidth(),
context_length,
context_start);
}
@@ -348,19 +346,18 @@ void hl_context_projection_backward_weight(real* out_grad,
template <>
void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
GpuMatrix* out_grad,
GpuMatrix* w_grad,
GpuMatrix& out_grad,
GpuMatrix& w_grad,
const GpuIVector& seq_vec,
size_t context_length,
int context_start,
size_t total_pad,
size_t begin_pad) {
CHECK(out_grad && w_grad);
hl_context_projection_backward_weight(out_grad->getData(),
hl_context_projection_backward_weight(out_grad.getData(),
seq_vec.getData(),
w_grad->getData(),
w_grad.getData(),
seq_vec.getSize() - 1,
w_grad->getWidth(),
w_grad.getWidth(),
total_pad,
context_length,
context_start,
@@ -368,16 +365,15 @@ void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
}
template <>
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix* out_grad,
GpuMatrix* in_grad,
GpuMatrix* w_grad,
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
GpuMatrix& in_grad,
GpuMatrix& w_grad,
const GpuIVector& sequence,
size_t context_length,
int context_start,
size_t begin_pad,
bool is_padding,
size_t total_pad) {
CHECK(out_grad);
if (in_grad) {
ContextProjectionBackwardData<DEVICE_TYPE_GPU>(
out_grad,
......
@@ -112,6 +112,8 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
}
/**
* \brief {o_0, o_1} = calc(i_0)
*
* \param inputs[0] input value.
* \param outputs[0] output value.
* \param outputs[1] denoms.
@@ -125,27 +127,24 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(2, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormal<Device>(outputs[0].getData(),
outputs[1].getData(),
inputs[0].getData(),
CrossMapNormal<Device>(outputs[0].data<real>(),
outputs[1].data<real>(),
inputs[0].data<real>(),
samples,
channels,
height,
@@ -162,6 +161,8 @@ private:
};
/**
* \brief {o_0} = calc(i_0, i_1, i_2, i_3)
*
* \param inputs[0] input value.
* \param inputs[1] output value.
* \param inputs[2] output grad.
@@ -177,31 +178,29 @@ public:
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(4, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CrossMapNormalGrad<Device>(outputs[0].getData(),
inputs[0].getData(),
inputs[1].getData(),
inputs[2].getData(),
inputs[3].getData(),
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
// TODO(hedaoyuan): need to support the ASSIGN_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
inputs[2].data<real>(),
inputs[3].data<real>(),
samples,
channels,
height,
......
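
For orientation, CrossMapNormal is cross-channel (local response) normalization: each output value is the input scaled by a power of a normalization term accumulated over neighboring channels at the same pixel. A per-pixel reference sketch under the usual LRN formulation; the size and scale parameter names are assumptions here, since only pow is visible in the excerpt above:

#include <algorithm>
#include <cmath>

// denoms[c] = 1 + scale * sum of in[k]^2 over a window of `size` channels
// centered at c; out[c] = in[c] * pow(denoms[c], -beta). One pixel only.
void crossMapNormalRef(const float* in, float* out, float* denoms,
                       int channels, int size, float scale, float beta) {
  for (int c = 0; c < channels; ++c) {
    int begin = std::max(0, c - size / 2);
    int end = std::min(channels, c + size / 2 + 1);
    float sum = 0;
    for (int k = begin; k < end; ++k) sum += in[k] * in[k];
    denoms[c] = 1 + scale * sum;
    out[c] = in[c] * std::pow(denoms[c], -beta);
  }
}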
@@ -76,6 +76,20 @@ FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
return *this;
}
void BufferArgs::addArg(const Matrix& arg,
const TensorShape& shape,
ArgType argType) {
args_.push_back(std::make_shared<BufferArg>(arg, shape, argType));
}
void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
}
void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
args_.push_back(std::make_shared<SparseMatrixArg>(arg, argType));
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle
@@ -16,57 +16,17 @@ limitations under the License. */
#include <map>
#include <vector>
#include "BufferArg.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <DeviceType Device>
struct SequenceT;
template <>
struct SequenceT<DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct SequenceT<DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
typedef std::vector<size_t> Dims;
class Tensor {
public:
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
real* getData() const { return buf_; }
real* buf_;
Dims dims_;
};
typedef std::vector<Tensor> Arguments;
/**
 * Function Configuration.
 * The argument type of Function::init.
 * A follow-up change may move this data structure into a Proto.
 */
class FuncConfig {
public:
union value {
@@ -86,15 +46,70 @@ protected:
std::map<std::string, value> valueMap_;
};
/**
 * Argument type for Function::calc().
 * A BufferArgs holds a set of BufferArg objects,
 * because a Function can have multiple inputs and outputs.
 */
class BufferArgs {
public:
BufferArgs() {}
size_t size() const { return args_.size(); }
// Add an argument into BufferArgs.
// Tensor can be Matrix, Vector, or IVector.
// Inputs do not need an argType.
// Outputs must specify argType, either ASSIGN_TO or ADD_TO.
template <typename Tensor>
void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) {
args_.push_back(std::make_shared<BufferArg>(arg, argType));
}
// Add an argument into BufferArgs and reshape it.
//
// For example, when arg represents an image buffer,
// a Matrix can only describe a two-dimensional Tensor,
// so an extra shape argument is needed to describe the image buffer.
void addArg(const Matrix& arg,
const TensorShape& shape,
ArgType argType = UNSPECIFIED);
void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED);
// get argument
const BufferArg& operator[](size_t num) const {
CHECK_LT(num, args_.size());
return *args_[num];
}
private:
std::vector<BufferArgPtr> args_;
};
/**
 * \brief Base class for Function.
 * A Function implementation must override the init and calc interfaces.
 *
 * Function inputs are read-only; Function outputs have two modes: ASSIGN_TO
 * and ADD_TO.
 * If output.getArgType() == ASSIGN_TO, the calculation result of the
 * Function is assigned to the output BufferArg.
 * If output.getArgType() == ADD_TO, the calculation result of the
 * Function is added to the output BufferArg.
 *
 * For example:
 * ASSIGN_TO: output = Function(inputs)
 * ADD_TO: output += Function(inputs)
 * If a Function has more than one output, each output can use a different
 * mode.
 */
class FunctionBase {
public:
virtual ~FunctionBase() {}
virtual void init(const FuncConfig& config) {}
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
static ClassRegistrar<FunctionBase> funcRegistrar_;
};
......
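
The two output modes amount to output = f(inputs) versus output += f(inputs). A minimal hypothetical Function honoring this contract, assumed to live in namespace paddle alongside Function.h; the class and its "scale" config key are illustrative only, not part of this commit:

// Element-wise scaling: out = scale * in, or out += scale * in in ADD_TO mode.
class ScaleFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    scale_ = config.get<real>("scale");  // hypothetical config key
  }
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(inputs.size(), (size_t)1);
    CHECK_EQ(outputs.size(), (size_t)1);
    CHECK(inputs[0].shape() == outputs[0].shape());
    const real* in = inputs[0].data<real>();
    real* out = outputs[0].data<real>();
    size_t n = inputs[0].shape().getElements();
    if (outputs[0].getArgType() == ADD_TO) {
      for (size_t i = 0; i < n; ++i) out[i] += scale_ * in[i];
    } else {
      for (size_t i = 0; i < n; ++i) out[i] = scale_ * in[i];
    }
  }

private:
  real scale_;
};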
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include <gtest/gtest.h>
namespace paddle {
template <DeviceType DType>
void FunctionApi(typename Tensor<real, DType>::Matrix& output,
const typename Tensor<real, DType>::Matrix& input);
template <>
void FunctionApi<DEVICE_TYPE_CPU>(CpuMatrix& output, const CpuMatrix& input) {
EXPECT_EQ(output.getHeight(), 100);
EXPECT_EQ(output.getWidth(), 200);
}
template <>
void FunctionApi<DEVICE_TYPE_GPU>(GpuMatrix& output, const GpuMatrix& input) {
EXPECT_EQ(output.getHeight(), 10);
EXPECT_EQ(output.getWidth(), 20);
}
template <DeviceType DType>
void Function(const BufferArgs& arguments) {
const auto input = arguments[0].matrix<DType>();
auto output = arguments[1].matrix<DType>();
FunctionApi<DType>(output, input);
}
TEST(Function, BufferArgs) {
CpuMatrix cpuInput = CpuMatrix(100, 200);
CpuMatrix cpuOutput = CpuMatrix(100, 200);
BufferArgs cpuArguments;
cpuArguments.addArg(cpuInput);
cpuArguments.addArg(cpuOutput);
Function<DEVICE_TYPE_CPU>(cpuArguments);
GpuMatrix gpuInput = GpuMatrix(10, 20);
GpuMatrix gpuOutput = GpuMatrix(10, 20);
BufferArgs gpuArguments;
gpuArguments.addArg(gpuInput);
gpuArguments.addArg(gpuOutput);
Function<DEVICE_TYPE_GPU>(gpuArguments);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
namespace paddle {
/**
 * TensorShape is used to represent the shape of a normal (dense) tensor.
 */
class TensorShape {
public:
TensorShape() : ndims_(0), nelements_(0) { initDims(0); }
TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); };
TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
initDims(ndims_);
dims_.assign(dims);
numElements();
};
TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_);
dims_.assign(t.dims_.begin(), t.dims_.end());
};
// get the size of the specified dimension
size_t operator[](size_t dim) const {
CHECK_GE(dim, (size_t)0);
CHECK_LT(dim, ndims_);
return dims_[dim];
}
// set the size of the specified dimension
void setDim(size_t dim, size_t size) {
CHECK_GE(dim, (size_t)0);
CHECK_LT(dim, ndims_);
dims_[dim] = size;
numElements();
}
// number of dimensions of the tensor
size_t ndims() const { return ndims_; }
size_t getElements() const { return nelements_; }
bool operator==(const TensorShape& t) const {
if (ndims() != t.ndims()) return false;
for (size_t i = 0; i < ndims(); i++) {
if (dims_[i] != t.dims_[i]) return false;
}
return true;
}
bool operator!=(const TensorShape& t) const { return !(*this == t); }
private:
// compute number of elements
void numElements() {
nelements_ = 1;
for (size_t n = 0; n < ndims_; n++) {
nelements_ *= dims_[n];
}
}
// init dims_
void initDims(size_t ndims) {
size_t count = ndims < 4 ? 4 : ndims;
dims_.assign(count, 1);
}
// number of dimensions
// ndims_ may not equal dims_.size()
size_t ndims_;
// number of elements
size_t nelements_;
std::vector<size_t> dims_;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "TensorShape.h"
#include <gtest/gtest.h>
namespace paddle {
TEST(TensorShape, Constructor) {
TensorShape t1;
EXPECT_EQ(t1.ndims(), 0);
EXPECT_EQ(t1.getElements(), 0);
TensorShape t2(3);
EXPECT_EQ(t2.ndims(), 3);
EXPECT_EQ(t2.getElements(), 1);
TensorShape t3({8, 10});
EXPECT_EQ(t3.ndims(), 2);
EXPECT_EQ(t3.getElements(), 80);
TensorShape t4(t3);
EXPECT_EQ(t4.ndims(), t3.ndims());
EXPECT_EQ(t4.getElements(), t3.getElements());
TensorShape t5({1, 2, 3, 4, 5});
EXPECT_EQ(t5.ndims(), 5);
EXPECT_EQ(t5.getElements(), 120);
}
TEST(TensorShape, GetAndSet) {
TensorShape t({1, 2, 3});
EXPECT_EQ(t.ndims(), 3);
EXPECT_EQ(t.getElements(), 6);
EXPECT_EQ(t[1], 2);
t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300);
EXPECT_EQ(t[1], 100);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/Matrix.h"
namespace paddle {
enum ValueType {
VALUE_TYPE_INT32 = 0,
VALUE_TYPE_FLOAT = 1,
VALUE_TYPE_DOUBLE = 2,
VALUE_TYPE_BYTE = 3
};
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2
};
inline int sizeOfValuType(ValueType valueType) {
if (valueType == VALUE_TYPE_INT32) {
return 4;
} else if (valueType == VALUE_TYPE_FLOAT) {
return 4;
} else if (valueType == VALUE_TYPE_DOUBLE) {
return 8;
} else {
LOG(FATAL) << "Unknown type: " << valueType;
return 0;
}
}
template <typename T>
struct DataType;
template <>
struct DataType<float> {
static const ValueType value = VALUE_TYPE_FLOAT;
};
template <>
struct DataType<double> {
static const ValueType value = VALUE_TYPE_DOUBLE;
};
template <>
struct DataType<int> {
static const ValueType value = VALUE_TYPE_INT32;
};
namespace detail {
template <typename VType, DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<real, DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<real, DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
template <>
struct MatrixT<int, DEVICE_TYPE_CPU> {
using type = void; // Not implemented
};
template <>
struct MatrixT<int, DEVICE_TYPE_GPU> {
using type = void; // Not implemented
};
template <typename VType, DeviceType Device>
struct VectorT;
template <>
struct VectorT<real, DEVICE_TYPE_CPU> {
using type = CpuVector;
};
template <>
struct VectorT<real, DEVICE_TYPE_GPU> {
using type = GpuVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
} // namespace detail
template <typename VType, DeviceType DType>
struct Tensor {
typedef typename detail::MatrixT<VType, DType>::type Matrix;
typedef typename detail::VectorT<VType, DType>::type Vector;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "TensorType.h"
#include <gtest/gtest.h>
namespace paddle {
TEST(TensorType, Matrix) {
Tensor<real, DEVICE_TYPE_CPU>::Matrix matrix(100, 200);
EXPECT_EQ(matrix.getHeight(), 100);
EXPECT_EQ(matrix.getWidth(), 200);
EXPECT_EQ(matrix.getElementCnt(), 100 * 200);
EXPECT_EQ(matrix.useGpu(), false);
Tensor<real, DEVICE_TYPE_GPU>::Matrix testGpu(100, 200);
EXPECT_EQ(testGpu.useGpu(), true);
}
TEST(TensorType, Vector) {
Tensor<real, DEVICE_TYPE_CPU>::Vector cpuVector(100);
Tensor<real, DEVICE_TYPE_GPU>::Vector gpuVector(100);
EXPECT_EQ(cpuVector.useGpu(), false);
EXPECT_EQ(gpuVector.useGpu(), true);
EXPECT_EQ(cpuVector.getSize(), 100);
EXPECT_EQ(gpuVector.getSize(), 100);
Tensor<int, DEVICE_TYPE_CPU>::Vector cpuIVector(100);
Tensor<int, DEVICE_TYPE_GPU>::Vector gpuIVector(100);
EXPECT_EQ(cpuIVector.useGpu(), false);
EXPECT_EQ(gpuIVector.useGpu(), true);
EXPECT_EQ(cpuIVector.getSize(), 100);
EXPECT_EQ(gpuIVector.getSize(), 100);
}
TEST(TensorType, EmptyMatrix) {
CpuMatrix empty(nullptr, 0, 0);
CpuMatrix nonEmpty(10, 10);
EXPECT_EQ(empty.isEmpty(), true);
EXPECT_EQ(nonEmpty.isEmpty(), false);
CHECK(nonEmpty);
auto function = [](const CpuMatrix& matrix) {
if (matrix) {
EXPECT_NE(matrix.getData(), nullptr);
} else {
EXPECT_EQ(matrix.getData(), nullptr);
}
};
function(empty);
function(nonEmpty);
}
} // namespace paddle
@@ -110,9 +110,8 @@ void ContextProjection::forward() {
size_t input_dim = in_->value->getWidth();
size_t dim = out_->value->getWidth();
CHECK_EQ(dim, input_dim * config_.context_length());
size_t batch_size = in_->value->getHeight();
CHECK_EQ(static_cast<int>(forward_.size()), 1)
<< "Only one forward function here";
// size_t batch_size = in_->value->getHeight();
CHECK_EQ(forward_.size(), (size_t)1) << "Only one forward function here";
REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
bool is_padding = config_.trainable_padding();
@@ -120,14 +119,16 @@
auto w_ptr =
state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
auto start_pos = in_->sequenceStartPositions;
forward_[0]->calc({Tensor(in_->value->getData(), Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->value->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*in_->value);
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->value, ADD_TO);
forward_[0]->calc(inputs, outputs);
if (state_ && config_.context_start() < 0) {
CHECK_EQ(1, in_->getNumSequences());
@@ -162,15 +163,17 @@ void ContextProjection::backward(const UpdateCallback& callback) {
bool is_padding = config_.trainable_padding();
auto start_pos = in_->sequenceStartPositions;
auto w_ptr = is_padding ? weight_->getWGrad() : nullptr;
backward_[0]->calc({Tensor(in_->grad ? in_->grad->getData() : nullptr,
Dims{batch_size, input_dim}),
Tensor(w_ptr ? w_ptr->getData() : nullptr,
Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}),
Tensor(reinterpret_cast<real*>(
const_cast<int*>(start_pos->getData(useGpu_))),
Dims{start_pos->getSize()})},
{Tensor(out_->grad->getData(), Dims{batch_size, dim})},
{});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(CpuMatrix(
in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim));
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->grad, ADD_TO);
backward_[0]->calc(inputs, outputs);
if (config_.trainable_padding()) {
weight_->getParameterPtr()->incUpdate(callback);
......
@@ -59,7 +59,6 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
void CMRProjectionNormLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
/* note: one sample corresponds to one row */
MatrixPtr input = inputLayers_[0]->getOutputValue();
@@ -67,34 +66,36 @@ void CMRProjectionNormLayer::forward(PassType passType) {
int size = getSize();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_[0]->calc(
{Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
// prepare forward arguments
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), shape_);
outputs.addArg(*getOutputValue(), shape_, ASSIGN_TO);
outputs.addArg(*denoms_, shape_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
}
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == inputLayers_[0]->getOutputGrad()) {
if (NULL == getInputGrad(0)) {
return;
}
/* Do derivation */
MatrixPtr preOutGrad = inputLayers_[0]->getOutputGrad();
MatrixPtr localGrad = getOutputGrad();
MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
// prepare backward arguments
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), shape_);
inputs.addArg(*getOutputValue(), shape_);
inputs.addArg(*getOutputGrad(), shape_);
inputs.addArg(*denoms_, shape_);
outputs.addArg(*getInputGrad(0), shape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
} // namespace paddle
@@ -41,6 +41,6 @@ public:
void backward(const UpdateCallback& callback = nullptr);
protected:
Dims dims_;
TensorShape shape_;
};
} // namespace paddle
@@ -1091,6 +1091,10 @@ public:
TensorCpuApply<real>(*this, expr);
}
}
bool isEmpty() const { return data_ == nullptr; }
explicit operator bool() const { return !isEmpty(); }
};
inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
......