Commit 781b85b5 authored by H Haonan

* rotate_layer and flip_layer
* added getMin and getMax for GpuMatrix
* gru_step_layer parameter name

Parent c1f9cd9d
@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
/**
- * A layer for transposition.
+ * A layer for transposing a minibatch matrix.
* \f[
y = x^\mathrm{T}
* \f]
...
@@ -1316,6 +1316,21 @@ TEST(Layer, ResizeLayer) {
}
}
TEST(Layer, RotateLayer) {
TestConfig config;
config.biasSize = 0;
config.layerConfig.set_type("rotate");
const int INPUT_SIZE = 64; // height * width
config.layerConfig.set_size(INPUT_SIZE);
config.layerConfig.set_height(32);
config.inputDefs.push_back({INPUT_DATA, "layer_0", INPUT_SIZE, 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "rotate", 100, false, useGpu);
}
}
TEST(Layer, NCELayer) {
TestConfig config;
size_t numClasses = 4;
...
@@ -372,7 +372,7 @@ MatrixPtr CpuSparseMatrix::subMatrix(size_t startRow, size_t numRows) {
}
/* mem MUST be alloced outside (memAlloc=false) */
- void CpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
+ void CpuSparseMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
CHECK(!memAlloc);
CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(matTrans.get());
if (format_ == SPARSE_CSR) {
...
@@ -201,7 +201,7 @@ public:
void zeroMem();
/// mem MUST be alloced outside (memAlloc=false)
- void transpose(MatrixPtr matTrans, bool memAlloc);
+ void transpose(MatrixPtr& matTrans, bool memAlloc);
void mul(const Matrix& A, const Matrix& B, real alpha, real beta);
...
@@ -274,6 +274,18 @@ real GpuMatrix::getSum() {
return sum;
}
real GpuMatrix::getMin() {
CHECK(isContiguous());
auto vec = GpuVector(height_ * width_, data_);
return vec.getMin();
}
real GpuMatrix::getMax() {
CHECK(isContiguous());
auto vec = GpuVector(height_ * width_, data_);
return vec.getMax();
}
void GpuMatrix::accumulateColSum(Matrix& src) {
CHECK_EQ(getWidth(), src.getWidth());
CHECK_EQ(getHeight(), (size_t)1);
@@ -371,7 +383,7 @@ MatrixPtr GpuMatrix::getTranspose() {
}
}
- void GpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
+ void GpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
if (memAlloc) {
matTrans = std::make_shared<GpuMatrix>(width_, height_);
} else {
@@ -385,13 +397,29 @@ void GpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
hl_matrix_transpose(data, dataTrans, height_, width_, lda, ldc);
}
void GpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
if (memAlloc) {
matRot = std::make_shared<GpuMatrix>(width_, height_);
} else {
CHECK(matRot != NULL);
}
MatrixPtr cpuMat = std::make_shared<CpuMatrix>(height_, width_);
cpuMat->copyFrom(*this);
MatrixPtr cpuMatRot = std::make_shared<CpuMatrix>(width_, height_);
cpuMat->rotate(cpuMatRot, false, clockWise);
matRot->copyFrom(*cpuMatRot);
}
MatrixPtr GpuMatrix::getInverse() {
MatrixPtr matInv;
inverse(matInv, true);
return matInv;
}
- void GpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) {
+ void GpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(height_, width_);
if (memAlloc) {
@@ -1690,7 +1718,7 @@ MatrixPtr CpuMatrix::getTranspose() {
}
}
- void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
+ void CpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
if (memAlloc) {
matTrans = std::make_shared<CpuMatrix>(width_, height_);
} else {
@@ -1708,13 +1736,35 @@ void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
}
}
void CpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
if (memAlloc) {
matRot = std::make_shared<CpuMatrix>(width_, height_);
} else {
CHECK(matRot != NULL);
}
real* dataRot = matRot->getData();
real* data = getData();
int lda = getStride();
int ldc = matRot->getStride();
for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) {
if (clockWise) {
dataRot[j * ldc + i] = data[(height_ - i - 1) * lda + j];
} else {
dataRot[j * ldc + i] = data[i * lda + (width_ - j - 1)];
}
}
}
}
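The nested loop above implements two index mappings: clockwise, rot(j, i) = x(height - i - 1, j), and counter-clockwise, rot(j, i) = x(i, width - j - 1). A minimal Python sketch of the same mapping, cross-checked against numpy (the rotate helper below is illustrative only, not part of the Paddle code):

import numpy as np

def rotate(mat, clock_wise=True):
    # mirrors CpuMatrix::rotate: the output has shape (width, height)
    height, width = len(mat), len(mat[0])
    rot = [[0] * height for _ in range(width)]
    for i in range(height):
        for j in range(width):
            if clock_wise:
                rot[j][i] = mat[height - i - 1][j]
            else:
                rot[j][i] = mat[i][width - j - 1]
    return rot

x = [[1, 2, 3],
     [4, 5, 6]]
assert rotate(x, clock_wise=True) == np.rot90(x, k=-1).tolist()   # clockwise
assert rotate(x, clock_wise=False) == np.rot90(x, k=1).tolist()   # counter-clockwise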
MatrixPtr CpuMatrix::getInverse() {
MatrixPtr matInv;
inverse(matInv, true);
return matInv;
}
- void CpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) {
+ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(height_, width_);
if (memAlloc) {
...
@@ -372,7 +372,17 @@ public:
* allocate matTrans' memory outside, then set memAlloc as false;
* else set as true.
*/
- virtual void transpose(MatrixPtr matTrans, bool memAlloc) {
+ virtual void transpose(MatrixPtr& matTrans, bool memAlloc) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief rotate clock-wise.
*
* allocate matRot's memory outside, then set memAlloc as false;
* else set as true.
*/
virtual void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -387,7 +397,7 @@ public: ...@@ -387,7 +397,7 @@ public:
* if allocate matInv's memory outside, then set memAlloc as false; * if allocate matInv's memory outside, then set memAlloc as false;
* else set as true. * else set as true.
*/ */
virtual void inverse(MatrixPtr matInv, bool memAlloc) { virtual void inverse(MatrixPtr& matInv, bool memAlloc) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -1169,11 +1179,15 @@ public: ...@@ -1169,11 +1179,15 @@ public:
void accumulateColSum(Matrix& src); void accumulateColSum(Matrix& src);
real getAbsSum(); real getAbsSum();
real getMin();
real getMax();
MatrixPtr getTranspose();
- void transpose(MatrixPtr matTrans, bool memAlloc);
+ void transpose(MatrixPtr& matTrans, bool memAlloc);
void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise);
MatrixPtr getInverse();
- void inverse(MatrixPtr matInv, bool memAlloc);
+ void inverse(MatrixPtr& matInv, bool memAlloc);
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
@@ -1485,10 +1499,11 @@ public:
real getAbsSum();
MatrixPtr getTranspose();
- void transpose(MatrixPtr matTrans, bool memAlloc);
+ void transpose(MatrixPtr& matTrans, bool memAlloc);
void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise);
MatrixPtr getInverse();
- void inverse(MatrixPtr matInv, bool memAlloc);
+ void inverse(MatrixPtr& matInv, bool memAlloc);
void copyFrom(const Matrix& src);
...
@@ -497,7 +497,7 @@ void GpuSparseMatrix::setRow(size_t row,
SparseValueType GpuSparseMatrix::getValueType() const { return valueType_; }
- void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
+ void GpuSparseMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
CHECK_EQ(format_, SPARSE_CSC);
int nnz = sMatrix_->nnz;
if (memAlloc) {
...
@@ -109,7 +109,7 @@ public:
MatrixPtr getTranspose();
/// B = A'
- void transpose(MatrixPtr matTrans, bool memAlloc);
+ void transpose(MatrixPtr& matTrans, bool memAlloc);
void copyFrom(const Matrix& src);
void copyFrom(const Matrix& src, hl_stream_t stream);
...
@@ -248,11 +248,13 @@ TEST(Matrix, SparseMatrixTranspose) {
/*dense matrix transpose*/
CpuMatrixPtr matC(new CpuMatrix(height, width));
matC->copyFrom(*matA);
- CpuMatrixPtr matD(new CpuMatrix(width, height));
+ MatrixPtr matD(new CpuMatrix(width, height));
matC->transpose(matD, false);
/*check result*/
checkSMatrixEqual2Dense(
- std::dynamic_pointer_cast<CpuSparseMatrix>(matB), matD);
+ std::dynamic_pointer_cast<CpuSparseMatrix>(matB),
+ std::dynamic_pointer_cast<CpuMatrix>(matD));
}
}
}
...
@@ -105,6 +105,21 @@ void testMatrixGetSum(int height, int width) {
EXPECT_LE(fabs(cpuSum - gpuSum), err);
}
void testMatrixGetMinMax(int height, int width) {
MatrixPtr cpuInput = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpuInput = std::make_shared<GpuMatrix>(height, width);
cpuInput->randomizeUniform();
gpuInput->copyFrom(*cpuInput);
real cpuMin = cpuInput->getMin();
real gpuMin = gpuInput->getMin();
real cpuMax = cpuInput->getMax();
real gpuMax = gpuInput->getMax();
EXPECT_EQ(cpuMin, gpuMin);
EXPECT_EQ(cpuMax, gpuMax);
}
void testMatrixZeroAtOffset(int height, int width) {
MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpuA = std::make_shared<GpuMatrix>(height, width);
@@ -181,7 +196,7 @@ void testMatrixInverse(int height) {
cpu->add(*outputCheck);
gpu->copyFrom(*cpu);
- cpu->inverse(cpuI, false);
+ cpu->inverse(cpuI, true);
gpu->inverse(gpuI, false);
TensorCheckErr(*cpuI, *gpuI);
...
@@ -830,7 +830,6 @@ class Pool(Cfg):
channels,
size_x,
size_y=None,
- img_width=None,
start=None,
stride=None, # 1 by defalut in protobuf
stride_y=None,
@@ -1834,6 +1833,7 @@ class PoolLayer(LayerBase):
pool_conf.channels)
@config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
@@ -1968,6 +1968,18 @@ class ResizeLayer(LayerBase):
'ResizeLayer must have one and only one input')
@config_layer('rotate')
class RotateLayer(LayerBase):
def __init__(self, name, inputs, height, device=None):
super(RotateLayer, self).__init__(
name, 'rotate', 0, inputs=inputs, device=device)
config_assert(
len(self.inputs) == 1,
'RotateLayer must have one and only one input')
self.config.height = height
self.set_layer_size(self.get_input_layer(0).size)
@config_layer('blockexpand')
class BlockExpandLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
...
@@ -70,6 +70,8 @@ __all__ = [
'interpolation_layer',
'bilinear_interp_layer',
'trans_layer',
'rotate_layer',
'flip_layer',
'sum_to_one_norm_layer',
'get_output_layer',
'LayerType',
@@ -154,6 +156,7 @@ class LayerType(object):
POWER_LAYER = 'power'
SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans'
ROTATE_LAYER = 'rotate'
OUT_PROD_LAYER = 'out_prod'
FEATURE_MAP_EXPAND_LAYER = 'featmap_expand'
@@ -1642,7 +1645,7 @@ def scaling_layer(input, weight, name=None, layer_attr=None):
@layer_support()
def trans_layer(input, name=None, layer_attr=None):
"""
- A layer for transposition.
+ A layer for transposing a minibatch matrix.
.. math::
y = x^\mathrm{T}
@@ -1673,6 +1676,87 @@ def trans_layer(input, name=None, layer_attr=None):
name, LayerType.TRANS_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def rotate_layer(input, height, name=None, layer_attr=None):
"""
A layer for 90-degree clockwise rotation, usually used when the input sample
is an image or a feature map.
.. math::
y(j,i) = x(M-i-1,j)
where :math:`x` is (M x N) input, and :math:`y` is (N x M) output.
The example usage is:
.. code-block:: python
rot = rotate_layer(input=layer,
height=100)
:param input: Input layer.
:type input: LayerOutput
:param height: The height of the sample matrix
:type height: int
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
l = Layer(name=name,
height=height,
type=LayerType.ROTATE_LAYER,
inputs=[input.name],
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name=name,
layer_type=LayerType.ROTATE_LAYER,
parents=[input],
size=l.config.size)
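A minimal usage sketch for the new layer (assuming the usual paddle.trainer_config_helpers imports; the layer names below are illustrative):

from paddle.trainer_config_helpers import *

# a 32x32 single-channel feature map, flattened to a 1024-dim vector
img = data_layer(name='image', size=32 * 32)
# rotate each sample clockwise; height is the number of rows of the sample matrix
rot = rotate_layer(input=img, height=32)
# the output size stays 32 * 32; only the row/column layout of each sample changes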
@wrap_name_default()
@layer_support()
def flip_layer(input, height, name=None, layer_attr=None):
"""
A layer for flipping a matrix about its center, which is equivalent to
rotating the matrix clockwise twice. Usually used when the input sample
is an image or a feature map.
.. math::
y(i,j) = x(M-i-1, N-j-1)
where :math:`x` is (M x N) input, and :math:`y` is (M x N) output.
The example usage is:
.. code-block:: python
flip = flip_layer(input=layer,
height=100)
:param input: Input layer.
:type input: LayerOutput
:param height: The height of the sample matrix
:type height: int
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
return rotate_layer(input=rotate_layer(input=input,
height=height),
height=height,
name=name,
layer_attr=layer_attr)
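Since flip_layer is built from two clockwise rotations, its output follows y(i,j) = x(M-i-1, N-j-1); a small numpy check of that identity (illustrative only):

import numpy as np

x = np.arange(6).reshape(2, 3)           # a 2 x 3 sample matrix
clockwise = lambda m: np.rot90(m, k=-1)  # one 90-degree clockwise rotation
flipped = clockwise(clockwise(x))        # rotate twice, as flip_layer does
assert (flipped == x[::-1, ::-1]).all()  # equals reversing both axes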
@wrap_name_default()
@layer_support()
def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
@@ -4739,6 +4823,7 @@ def cross_entropy_with_selfnorm(input,
layer_attr=None):
"""
A loss layer for multi class entropy with selfnorm.
Input should be a vector of positive numbers, without normalization.
.. code-block:: python
...
@@ -39,6 +39,10 @@ z1 = mixed_layer(
assert z1.size > 0
y2 = fc_layer(input=y, size=15)
z2 = rotate_layer(input=y2,
height=5)
z3 = flip_layer(input=y2,
height=3)
cos1 = cos_sim(a=x1, b=y1)
cos3 = cos_sim(a=x1, b=y2, size=3)
@@ -46,7 +50,7 @@ cos3 = cos_sim(a=x1, b=y2, size=3)
linear_comb = linear_comb_layer(weights=x1, vectors=y2, size=3)
out = fc_layer(
- input=[cos1, cos3, linear_comb, z, z1],
+ input=[cos1, cos3, linear_comb, z, z1, z2, z3],
size=num_classes,
act=SoftmaxActivation())
...