Commit b9dfe8e7 authored by Haonan, committed by GitHub

Merge pull request #1231 from yu239/rotate_and_flip

One bug fix and two new features
......@@ -267,4 +267,16 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
const int dimN,
real scale);
/**
* @brief Matrix rotation by 90 degrees
*
* @param[in] mat input matrix (M x N).
* @param[out] matRot output matrix (N x M).
* @param[in] dimM input matrix height.
* @param[in] dimN input matrix width.
* @param[in] clockWise rotation direction
*/
extern void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise);
#endif /* HL_MATRIX_H_ */
......@@ -106,4 +106,8 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
const int dimM,
const int dimN,
real scale) {}
inline void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
#endif // HL_MATRIX_STUB_H_
......@@ -840,3 +840,28 @@ void hl_matrix_collect_shared_bias(real* B_d,
(B_d, A_d, channel, dimM, dimN, dim, limit, scale);
CHECK_SYNC("hl_matrix_collect_shared_bias failed");
}
__global__ void keMatrixRotate(real* mat, real* matRot,
int dimM, int dimN, bool clockWise) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < dimM * dimN) {
int i = idx / dimN;
int j = idx % dimN;
if (clockWise) {
matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
} else {
matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
}
}
}
void hl_matrix_rotate(real *mat, real* matRot,
int dimM, int dimN, bool clockWise) {
CHECK_NOTNULL(mat);
CHECK_NOTNULL(matRot);
const int threads = 512;
const int blocks = DIVUP(dimM * dimN, threads);
keMatrixRotate<<< blocks, threads, 0, STREAM_DEFAULT >>>
(mat, matRot, dimM, dimN, clockWise);
CHECK_SYNC("hl_matrix_rotate failed");
}
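As a sanity check on the indexing above, here is a minimal CPU-side sketch (not part of this change; the names are illustrative) that applies the same clockwise mapping keMatrixRotate uses to a 2 x 3 matrix:
#include <cassert>
int main() {
  const int dimM = 2, dimN = 3;
  float mat[dimM * dimN] = {1, 2, 3,   // row 0
                            4, 5, 6};  // row 1
  float rot[dimN * dimM];              // output is dimN x dimM
  for (int i = 0; i < dimM; ++i) {
    for (int j = 0; j < dimN; ++j) {
      // clockwise: y(j, i) = x(dimM - i - 1, j)
      rot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
    }
  }
  // The 3 x 2 result, row by row: {4, 1}, {5, 2}, {6, 3}.
  assert(rot[0] == 4 && rot[1] == 1);
  assert(rot[2] == 5 && rot[3] == 2);
  assert(rot[4] == 6 && rot[5] == 3);
  return 0;
}
The counter-clockwise branch follows the same pattern with y(j, i) = x(i, dimN - 1 - j).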
......@@ -95,6 +95,9 @@ void FeatureMapExpandLayer::forward(PassType passType) {
void FeatureMapExpandLayer::backward(const UpdateCallback& callback) {
MatrixPtr inGrad = getInputGrad(0);
if (NULL == inGrad) {
return;
}
MatrixPtr outGrad = getOutputGrad();
size_t batchSize = getInput(0).getBatchSize();
int imgSize = inGrad->getWidth();
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "RotateLayer.h"
namespace paddle {
REGISTER_LAYER(rotate, RotateLayer);
bool RotateLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1UL);
height_ = config_.height();
width_ = config_.width();
CHECK_GT(height_, 0);
CHECK_GT(width_, 0);
return true;
}
void RotateLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = getInputValue(0);
batchSize_ = input->getHeight();
size_ = input->getWidth();
CHECK_GE(size_, height_ * width_);
CHECK_EQ(size_ % (height_ * width_), 0)
<< "total size_ is not dividable by (height_ * width_), i.e., "
<< "channel number should be an integer";
channels_ = size_ / (height_ * width_);
resizeOutput(batchSize_, size_);
MatrixPtr outV = getOutputValue();
for (int b = 0; b < batchSize_; b++) { // for each input feat map
for (int c = 0; c < channels_; c++) { // for each feat channel
MatrixPtr inputSample =
Matrix::create(input->getData() + b * size_ + c * height_ * width_,
height_,
width_,
false,
useGpu_);
MatrixPtr outputSample =
Matrix::create(outV->getData() + b * size_ + c * height_ * width_,
width_,
height_,
false,
useGpu_);
inputSample->rotate(outputSample, false, true /* clock-wise */);
}
}
if (getInputGrad(0)) {
zeroGrad();
}
}
void RotateLayer::backward(const UpdateCallback& callback) {
(void)callback;
MatrixPtr outputGrad = getOutputGrad();
if (outputGrad == NULL) {
return;
}
// the grad should be rotated in the reverse direction
MatrixPtr preGrad = getInputGrad(0);
for (int b = 0; b < batchSize_; b++) { // for each input feat map
for (int c = 0; c < channels_; c++) { // for each feat channel
MatrixPtr inputSampleGrad =
Matrix::create(preGrad->getData() + b * size_ + c * height_ * width_,
height_,
width_,
false,
useGpu_);
MatrixPtr outputSampleGrad = Matrix::create(
outputGrad->getData() + b * size_ + c * height_ * width_,
width_,
height_,
false,
useGpu_);
MatrixPtr tmpGrad = nullptr;
outputSampleGrad->rotate(tmpGrad, true, false /* anti clock-wise */);
inputSampleGrad->add(*tmpGrad);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* A layer for rotating a multi-channel feature map (M x N x C) in the spatial
* domain
* The rotation is 90 degrees clockwise for each channel.
* \f[
* y(j,i,:) = x(M-i-1,j,:)
* \f]
* where \f$x\f$ is (M x N x C) input, and \f$y\f$ is (N x M x C) output.
*
* The config file api is rotate_layer
*
*/
class RotateLayer : public Layer {
public:
explicit RotateLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
private:
int batchSize_;
int size_;
int height_;
int width_;
int channels_;
};
} // namespace paddle
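RotateLayer::forward and backward rely on each row of the (batchSize x size) input holding C contiguous blocks of height x width values, one per channel, and rotate each block independently. A minimal sketch of that per-sample slicing (assumptions: raw float buffers and a hypothetical helper name, not the Paddle Matrix API):
// Rotate every channel of one flattened (channels x height x width) sample
// clockwise by 90 degrees; the output holds channels blocks of width x height.
void rotateSampleClockwise(const float* in, float* out,
                           int channels, int height, int width) {
  for (int c = 0; c < channels; ++c) {
    const float* x = in + c * height * width;   // one height x width channel
    float* y = out + c * height * width;        // its width x height rotation
    for (int i = 0; i < height; ++i) {
      for (int j = 0; j < width; ++j) {
        y[j * height + i] = x[(height - i - 1) * width + j];
      }
    }
  }
}
Calling this once per row of the input batch mirrors what forward() does through Matrix::create views and Matrix::rotate; backward() applies the counter-clockwise mapping to the output gradient and accumulates it into the input gradient.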
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
/**
* A layer for transposition.
* A layer for transposing a minibatch matrix.
* \f[
y = x^\mathrm{T}
* \f]
......
......@@ -1316,6 +1316,25 @@ TEST(Layer, ResizeLayer) {
}
}
TEST(Layer, RotateLayer) {
TestConfig config;
config.biasSize = 0;
config.layerConfig.set_type("rotate");
const int CHANNEL = 2;
const int HEIGHT = 8;
const int WIDTH = 4;
const int INPUT_SIZE = HEIGHT * WIDTH * CHANNEL;
config.layerConfig.set_size(INPUT_SIZE);
config.layerConfig.set_height(HEIGHT);
config.layerConfig.set_width(WIDTH);
config.inputDefs.push_back({INPUT_DATA, "layer_0", INPUT_SIZE, 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "rotate", 100, false, useGpu);
}
}
TEST(Layer, NCELayer) {
TestConfig config;
size_t numClasses = 4;
......
......@@ -372,7 +372,7 @@ MatrixPtr CpuSparseMatrix::subMatrix(size_t startRow, size_t numRows) {
}
/* mem MUST be alloced outside (memAlloc=false) */
void CpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
void CpuSparseMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
CHECK(!memAlloc);
CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(matTrans.get());
if (format_ == SPARSE_CSR) {
......
......@@ -201,7 +201,7 @@ public:
void zeroMem();
/// mem MUST be alloced outside (memAlloc=false)
void transpose(MatrixPtr matTrans, bool memAlloc);
void transpose(MatrixPtr& matTrans, bool memAlloc);
void mul(const Matrix& A, const Matrix& B, real alpha, real beta);
......
......@@ -274,6 +274,18 @@ real GpuMatrix::getSum() {
return sum;
}
real GpuMatrix::getMin() {
CHECK(isContiguous());
auto vec = GpuVector(height_ * width_, data_);
return vec.getMin();
}
real GpuMatrix::getMax() {
CHECK(isContiguous());
auto vec = GpuVector(height_ * width_, data_);
return vec.getMax();
}
void GpuMatrix::accumulateColSum(Matrix& src) {
CHECK_EQ(getWidth(), src.getWidth());
CHECK_EQ(getHeight(), (size_t)1);
......@@ -371,11 +383,13 @@ MatrixPtr GpuMatrix::getTranspose() {
}
}
void GpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
void GpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
if (memAlloc) {
matTrans = std::make_shared<GpuMatrix>(width_, height_);
} else {
CHECK(matTrans != NULL);
CHECK_EQ(matTrans->getHeight(), width_);
CHECK_EQ(matTrans->getWidth(), height_);
}
real* dataTrans = matTrans->getData();
real* data = getData();
......@@ -385,13 +399,27 @@ void GpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
hl_matrix_transpose(data, dataTrans, height_, width_, lda, ldc);
}
void GpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
if (memAlloc) {
matRot = std::make_shared<GpuMatrix>(width_, height_);
} else {
CHECK(matRot != NULL);
CHECK_EQ(matRot->getHeight(), width_);
CHECK_EQ(matRot->getWidth(), height_);
}
real* dataRot = matRot->getData();
real* data = getData();
hl_matrix_rotate(data, dataRot, height_, width_, clockWise);
}
MatrixPtr GpuMatrix::getInverse() {
MatrixPtr matInv;
inverse(matInv, true);
return matInv;
}
void GpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) {
void GpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(height_, width_);
if (memAlloc) {
......@@ -1690,11 +1718,13 @@ MatrixPtr CpuMatrix::getTranspose() {
}
}
void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
void CpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
if (memAlloc) {
matTrans = std::make_shared<CpuMatrix>(width_, height_);
} else {
CHECK(matTrans != NULL);
CHECK_EQ(matTrans->getHeight(), width_);
CHECK_EQ(matTrans->getWidth(), height_);
}
real* dataTrans = matTrans->getData();
real* data = getData();
......@@ -1708,13 +1738,35 @@ void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
}
}
void CpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
if (memAlloc) {
matRot = std::make_shared<CpuMatrix>(width_, height_);
} else {
CHECK(matRot != NULL);
CHECK_EQ(matRot->getHeight(), width_);
CHECK_EQ(matRot->getWidth(), height_);
}
real* dataRot = matRot->getData();
real* data = getData();
for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) {
if (clockWise) {
dataRot[j * height_ + i] = data[(height_ - i - 1) * width_ + j];
} else {
dataRot[j * height_ + i] = data[i * width_ + (width_ - j - 1)];
}
}
}
}
MatrixPtr CpuMatrix::getInverse() {
MatrixPtr matInv;
inverse(matInv, true);
return matInv;
}
void CpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) {
void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(height_, width_);
if (memAlloc) {
......
......@@ -372,7 +372,27 @@ public:
* If matTrans's memory is allocated outside, set memAlloc to false;
* otherwise set it to true.
*/
virtual void transpose(MatrixPtr matTrans, bool memAlloc) {
virtual void transpose(MatrixPtr& matTrans, bool memAlloc) {
LOG(FATAL) << "Not implemented";
}
/**
* @brief Rotate by 90 degrees clockwise if clockWise=true;
* otherwise rotate counter-clockwise.
* clock-wise:
* \f[
* y(j,i) = x(M-i-1,j)
* \f]
* anti clock-wise:
* \f[
* y(j,i) = x(i, N-1-j)
* \f]
* where \f$x\f$ is (M x N) input, and \f$y\f$ is (N x M) output.
*
* If matRot's memory is allocated outside, set memAlloc to false;
* otherwise set it to true.
*/
virtual void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
LOG(FATAL) << "Not implemented";
}
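The memAlloc flag only decides who owns the output buffer; when memAlloc=false, the CpuMatrix and GpuMatrix implementations check the output's dimensions against the (width x height) result. A minimal usage sketch, assuming the CpuMatrix interface exercised in test_matrixCompare.cpp later in this diff:
MatrixPtr x = std::make_shared<CpuMatrix>(3, 4);    // M x N input
x->randomizeUniform();
MatrixPtr rot;                                       // left unallocated
x->rotate(rot, true, true);    // memAlloc=true: rotate() allocates the 4 x 3 result
MatrixPtr pre = std::make_shared<CpuMatrix>(4, 3);   // caller-allocated N x M
x->rotate(pre, false, false);  // memAlloc=false: the existing buffer is reused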
......@@ -387,7 +407,7 @@ public:
* If matInv's memory is allocated outside, set memAlloc to false;
* otherwise set it to true.
*/
virtual void inverse(MatrixPtr matInv, bool memAlloc) {
virtual void inverse(MatrixPtr& matInv, bool memAlloc) {
LOG(FATAL) << "Not implemented";
}
......@@ -1169,11 +1189,15 @@ public:
void accumulateColSum(Matrix& src);
real getAbsSum();
real getMin();
real getMax();
MatrixPtr getTranspose();
void transpose(MatrixPtr matTrans, bool memAlloc);
void transpose(MatrixPtr& matTrans, bool memAlloc);
void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise);
MatrixPtr getInverse();
void inverse(MatrixPtr matInv, bool memAlloc);
void inverse(MatrixPtr& matInv, bool memAlloc);
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
......@@ -1485,10 +1509,11 @@ public:
real getAbsSum();
MatrixPtr getTranspose();
void transpose(MatrixPtr matTrans, bool memAlloc);
void transpose(MatrixPtr& matTrans, bool memAlloc);
void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise);
MatrixPtr getInverse();
void inverse(MatrixPtr matInv, bool memAlloc);
void inverse(MatrixPtr& matInv, bool memAlloc);
void copyFrom(const Matrix& src);
......
......@@ -497,7 +497,7 @@ void GpuSparseMatrix::setRow(size_t row,
SparseValueType GpuSparseMatrix::getValueType() const { return valueType_; }
void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
void GpuSparseMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
CHECK_EQ(format_, SPARSE_CSC);
int nnz = sMatrix_->nnz;
if (memAlloc) {
......
......@@ -109,7 +109,7 @@ public:
MatrixPtr getTranspose();
/// B = A'
void transpose(MatrixPtr matTrans, bool memAlloc);
void transpose(MatrixPtr& matTrans, bool memAlloc);
void copyFrom(const Matrix& src);
void copyFrom(const Matrix& src, hl_stream_t stream);
......
......@@ -248,11 +248,13 @@ TEST(Matrix, SparseMatrixTranspose) {
/*dense matrix transpose*/
CpuMatrixPtr matC(new CpuMatrix(height, width));
matC->copyFrom(*matA);
CpuMatrixPtr matD(new CpuMatrix(width, height));
MatrixPtr matD(new CpuMatrix(width, height));
matC->transpose(matD, false);
/*check result*/
checkSMatrixEqual2Dense(
std::dynamic_pointer_cast<CpuSparseMatrix>(matB), matD);
std::dynamic_pointer_cast<CpuSparseMatrix>(matB),
std::dynamic_pointer_cast<CpuMatrix>(matD));
}
}
}
......
......@@ -105,6 +105,21 @@ void testMatrixGetSum(int height, int width) {
EXPECT_LE(fabs(cpuSum - gpuSum), err);
}
void testMatrixGetMinMax(int height, int width) {
MatrixPtr cpuInput = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpuInput = std::make_shared<GpuMatrix>(height, width);
cpuInput->randomizeUniform();
gpuInput->copyFrom(*cpuInput);
real cpuMin = cpuInput->getMin();
real gpuMin = gpuInput->getMin();
real cpuMax = cpuInput->getMax();
real gpuMax = gpuInput->getMax();
EXPECT_EQ(cpuMin, gpuMin);
EXPECT_EQ(cpuMax, gpuMax);
}
void testMatrixZeroAtOffset(int height, int width) {
MatrixPtr cpuA = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpuA = std::make_shared<GpuMatrix>(height, width);
......@@ -161,11 +176,29 @@ void testMatrixTranspose(int height, int width) {
cpu->randomizeUniform();
gpu->copyFrom(*cpu);
cpu->transpose(cpuT, false);
gpu->transpose(gpuT, false);
gpu->transpose(gpuT, true);
TensorCheckEqual(*cpuT, *gpuT);
}
void testMatrixRotate(int height, int width) {
MatrixPtr cpu = std::make_shared<CpuMatrix>(height, width);
MatrixPtr gpu = std::make_shared<GpuMatrix>(height, width);
MatrixPtr cpuR = std::make_shared<CpuMatrix>(width, height);
MatrixPtr gpuR = std::make_shared<GpuMatrix>(width, height);
cpu->randomizeUniform();
gpu->copyFrom(*cpu);
cpu->rotate(cpuR, false, true);
gpu->rotate(gpuR, true, true);
TensorCheckEqual(*cpuR, *gpuR);
cpu->rotate(cpuR, true, false);
gpu->rotate(gpuR, false, false);
TensorCheckEqual(*cpuR, *gpuR);
}
void testMatrixInverse(int height) {
MatrixPtr cpu = std::make_shared<CpuMatrix>(height, height);
MatrixPtr gpu = std::make_shared<GpuMatrix>(height, height);
......@@ -181,7 +214,7 @@ void testMatrixInverse(int height) {
cpu->add(*outputCheck);
gpu->copyFrom(*cpu);
cpu->inverse(cpuI, false);
cpu->inverse(cpuI, true);
gpu->inverse(gpuI, false);
TensorCheckErr(*cpuI, *gpuI);
......@@ -200,6 +233,7 @@ TEST(Matrix, unary) {
testMatrixZeroAtOffset(height, width);
testMatrixGetSum(height, width);
testMatrixTranspose(height, width);
testMatrixRotate(height, width);
}
// inverse
testMatrixInverse(height);
......
......@@ -830,7 +830,6 @@ class Pool(Cfg):
channels,
size_x,
size_y=None,
img_width=None,
start=None,
stride=None, # 1 by default in protobuf
stride_y=None,
......@@ -1968,6 +1967,18 @@ class ResizeLayer(LayerBase):
'ResizeLayer must have one and only one input')
@config_layer('rotate')
class RotateLayer(LayerBase):
def __init__(self, name, inputs, height, width, device=None):
super(RotateLayer, self).__init__(
name, 'rotate', 0, inputs=inputs, device=device)
config_assert(
len(self.inputs) == 1,
'RotateLayer must have one and only one input')
self.set_layer_height_width(height, width)
self.set_layer_size(self.get_input_layer(0).size)
@config_layer('blockexpand')
class BlockExpandLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
......@@ -70,6 +70,7 @@ __all__ = [
'interpolation_layer',
'bilinear_interp_layer',
'trans_layer',
'rotate_layer',
'sum_to_one_norm_layer',
'get_output_layer',
'LayerType',
......@@ -154,6 +155,7 @@ class LayerType(object):
POWER_LAYER = 'power'
SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans'
ROTATE_LAYER = 'rotate'
OUT_PROD_LAYER = 'out_prod'
FEATURE_MAP_EXPAND_LAYER = 'featmap_expand'
......@@ -1642,7 +1644,7 @@ def scaling_layer(input, weight, name=None, layer_attr=None):
@layer_support()
def trans_layer(input, name=None, layer_attr=None):
"""
A layer for transposition.
A layer for transposing a minibatch matrix.
.. math::
y = x^\mathrm{T}
......@@ -1673,6 +1675,52 @@ def trans_layer(input, name=None, layer_attr=None):
name, LayerType.TRANS_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def rotate_layer(input, height, width, name=None, layer_attr=None):
"""
A layer for rotating each feature channel by 90 degrees (clockwise),
typically used when the input sample is an image or a feature map.
.. math::
y(j,i,:) = x(M-i-1,j,:)
where :math:`x` is (M x N x C) input, and :math:`y` is (N x M x C) output.
The example usage is:
.. code-block:: python
rot = rotate_layer(input=layer,
height=100,
width=100)
:param input: Input layer.
:type input: LayerOutput
:param height: The height of the sample matrix.
:type height: int
:param width: The width of the sample matrix.
:type width: int
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
l = Layer(
name=name,
height=height,
width=width,
type=LayerType.ROTATE_LAYER,
inputs=[input.name],
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name=name,
layer_type=LayerType.ROTATE_LAYER,
parents=[input],
size=l.config.size)
@wrap_name_default()
@layer_support()
def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
......@@ -4746,6 +4794,7 @@ def cross_entropy_with_selfnorm(input,
layer_attr=None):
"""
A loss layer for multi-class cross entropy with selfnorm.
Input should be a vector of positive numbers, without normalization.
.. code-block:: python
......
......@@ -39,6 +39,7 @@ z1 = mixed_layer(
assert z1.size > 0
y2 = fc_layer(input=y, size=15)
z2 = rotate_layer(input=y2, height=5, width=3)
cos1 = cos_sim(a=x1, b=y1)
cos3 = cos_sim(a=x1, b=y2, size=3)
......@@ -46,7 +47,7 @@ cos3 = cos_sim(a=x1, b=y2, size=3)
linear_comb = linear_comb_layer(weights=x1, vectors=y2, size=3)
out = fc_layer(
input=[cos1, cos3, linear_comb, z, z1],
input=[cos1, cos3, linear_comb, z, z1, z2],
size=num_classes,
act=SoftmaxActivation())
......