Commit 2558c3f1 authored by Haonan

revisions according to reviews

Parent b4c1d175
@@ -267,4 +267,16 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
                                           const int dimN,
                                           real scale);
 
+/**
+ * @brief Rotate a matrix by 90 degrees.
+ *
+ * @param[in]  mat        input matrix (M x N).
+ * @param[out] matRot     output matrix (N x M).
+ * @param[in]  dimM       input matrix height.
+ * @param[in]  dimN       input matrix width.
+ * @param[in]  clockWise  rotation direction.
+ */
+extern void hl_matrix_rotate(
+    real* mat, real* matRot, int dimM, int dimN, bool clockWise);
+
 #endif /* HL_MATRIX_H_ */
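To make the declared contract concrete, here is a minimal CPU-side sketch (not part of this commit; the values are illustrative) that checks the clockwise mapping the header promises, namely output(j, i) == input(dimM - 1 - i, j):

// Minimal sketch (illustrative, not in this commit): the clockwise contract
// of hl_matrix_rotate on a 2x3 row-major matrix, checked on the CPU.
#include <cassert>

int main() {
  const int dimM = 2, dimN = 3;
  float mat[6] = {1, 2, 3, 4, 5, 6};     // input, 2x3:  1 2 3 / 4 5 6
  float matRot[6] = {4, 1, 5, 2, 6, 3};  // expected output, 3x2: columns read bottom-up
  for (int i = 0; i < dimM; ++i) {
    for (int j = 0; j < dimN; ++j) {
      assert(matRot[j * dimM + i] == mat[(dimM - i - 1) * dimN + j]);
    }
  }
  return 0;
}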
@@ -106,4 +106,8 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
                                           const int dimM,
                                           const int dimN,
                                           real scale) {}
 
+inline void hl_matrix_rotate(
+    real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
+
 #endif  // HL_MATRIX_STUB_H_
@@ -840,3 +840,28 @@ void hl_matrix_collect_shared_bias(real* B_d,
       (B_d, A_d, channel, dimM, dimN, dim, limit, scale);
   CHECK_SYNC("hl_matrix_collect_shared_bias failed");
 }
+
+__global__ void keMatrixRotate(real* mat, real* matRot,
+                               int dimM, int dimN, bool clockWise) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per element
+  if (idx < dimM * dimN) {
+    int i = idx / dimN;  // row in the input matrix
+    int j = idx % dimN;  // column in the input matrix
+    if (clockWise) {
+      matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
+    } else {
+      matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
+    }
+  }
+}
+
+void hl_matrix_rotate(real* mat, real* matRot,
+                      int dimM, int dimN, bool clockWise) {
+  CHECK_NOTNULL(mat);
+  CHECK_NOTNULL(matRot);
+  const int threads = 512;
+  const int blocks = DIVUP(dimM * dimN, threads);
+  keMatrixRotate<<<blocks, threads, 0, STREAM_DEFAULT>>>
+      (mat, matRot, dimM, dimN, clockWise);
+  CHECK_SYNC("hl_matrix_rotate failed");
+}
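The launcher assigns one thread per matrix element. DIVUP is the usual ceiling-division helper; its definition lives elsewhere in the library's headers, so the line below shows the common idiom as an assumption rather than a quote of the project's code:

// Ceiling division: enough blocks of `threads` threads to cover all
// dimM * dimN elements. Surplus threads in the last block fail the
// `idx < dimM * dimN` guard in keMatrixRotate and simply do nothing.
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
// e.g. dimM * dimN = 1000, threads = 512  ->  DIVUP(1000, 512) == 2 blocks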
@@ -23,7 +23,8 @@ bool RotateLayer::init(const LayerMap& layerMap,
   Layer::init(layerMap, parameterMap);
 
   CHECK_EQ(inputLayers_.size(), 1UL);
-  sampleHeight_ = config_.height();
+  height_ = config_.height();
+  width_ = config_.width();
 
   return true;
 }
@@ -32,26 +33,31 @@ void RotateLayer::forward(PassType passType) {
   MatrixPtr input = getInputValue(0);
   batchSize_ = input->getHeight();
-  sampleSize_ = input->getWidth();
-  sampleWidth_ = sampleSize_ / sampleHeight_;
-  CHECK_EQ(sampleSize_ % sampleHeight_, 0);
+  size_ = input->getWidth();
+  CHECK_GE(size_, height_ * width_);
+  CHECK_EQ(size_ % (height_ * width_), 0)
+      << "The input's depth should be an integer";
+  channels_ = size_ / (height_ * width_);
 
-  resizeOutput(batchSize_, sampleSize_);
+  resizeOutput(batchSize_, size_);
 
   MatrixPtr outV = getOutputValue();
-  for (int b = 0; b < batchSize_; b++) {
-    MatrixPtr inputSample = Matrix::create(input->getData() + b * sampleSize_,
-                                           sampleHeight_,
-                                           sampleWidth_,
-                                           false,
-                                           useGpu_);
-    MatrixPtr outputSample = Matrix::create(outV->getData() + b * sampleSize_,
-                                            sampleWidth_,
-                                            sampleHeight_,
-                                            false,
-                                            useGpu_);
-    inputSample->rotate(outputSample, false, true);
+  for (int b = 0; b < batchSize_; b++) {    // for each input feat map
+    for (int c = 0; c < channels_; c++) {   // for each feat channel
+      MatrixPtr inputSample =
+          Matrix::create(input->getData() + b * size_ + c * height_ * width_,
+                         height_,
+                         width_,
+                         false,
+                         useGpu_);
+      MatrixPtr outputSample =
+          Matrix::create(outV->getData() + b * size_ + c * height_ * width_,
+                         width_,
+                         height_,
+                         false,
+                         useGpu_);
+      inputSample->rotate(outputSample, false, true /* clock-wise */);
+    }
   }
 
   if (getInputGrad(0)) {
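The offset arithmetic above encodes the layer's memory layout: each row of the input matrix is one sample of size_ = channels_ * height_ * width_ values, stored channel by channel. A small sketch of that indexing (hypothetical helper, not in the commit):

// Each sample occupies `size` consecutive values; channel c of sample b
// starts at this offset (size == channels * height * width).
#include <cstddef>

inline size_t sliceOffset(size_t b, size_t c,
                          size_t size, size_t height, size_t width) {
  return b * size + c * height * width;
}
// e.g. height=8, width=4, channels=2: sample 3, channel 1 starts at
// sliceOffset(3, 1, 64, 8, 4) == 3 * 64 + 32 == 224.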
@@ -69,23 +75,24 @@ void RotateLayer::backward(const UpdateCallback& callback) {
   // the grad should be rotated in the reverse direction
   MatrixPtr preGrad = getInputGrad(0);
 
-  for (int b = 0; b < batchSize_; b++) {
-    MatrixPtr inputSampleGrad =
-        Matrix::create(preGrad->getData() + b * sampleSize_,
-                       sampleHeight_,
-                       sampleWidth_,
-                       false,
-                       useGpu_);
-    MatrixPtr outputSampleGrad =
-        Matrix::create(outputGrad->getData() + b * sampleSize_,
-                       sampleWidth_,
-                       sampleHeight_,
-                       false,
-                       useGpu_);
-    MatrixPtr tmpGrad =
-        Matrix::create(sampleHeight_, sampleWidth_, false, useGpu_);
-    outputSampleGrad->rotate(tmpGrad, false, false);
-    inputSampleGrad->add(*tmpGrad);
+  for (int b = 0; b < batchSize_; b++) {    // for each input feat map
+    for (int c = 0; c < channels_; c++) {   // for each feat channel
+      MatrixPtr inputSampleGrad =
+          Matrix::create(preGrad->getData() + b * size_ + c * height_ * width_,
+                         height_,
+                         width_,
+                         false,
+                         useGpu_);
+      MatrixPtr outputSampleGrad = Matrix::create(
+          outputGrad->getData() + b * size_ + c * height_ * width_,
+          width_,
+          height_,
+          false,
+          useGpu_);
+      MatrixPtr tmpGrad = nullptr;
+      outputSampleGrad->rotate(tmpGrad, true, false /* anti clock-wise */);
+      inputSampleGrad->add(*tmpGrad);
+    }
   }
 }
...
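backward() undoes the forward rotation by rotating the output gradient anti-clockwise (note tmpGrad is now passed as nullptr with memAlloc=true, so rotate() allocates it). That the anti-clockwise rotation is the exact inverse can be checked with a small standalone sketch that mirrors the commit's index math (not part of the commit):

// Round-trip check: clockwise then anti-clockwise rotation is the identity,
// which is what lets backward() invert forward() by passing clockWise=false.
#include <cassert>
#include <vector>

std::vector<float> rotate(const std::vector<float>& x, int M, int N, bool cw) {
  std::vector<float> y(M * N);  // output is N x M, row-major
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      y[j * M + i] = cw ? x[(M - i - 1) * N + j] : x[i * N + (N - j - 1)];
    }
  }
  return y;
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6};  // 2 x 3
  auto roundTrip = rotate(rotate(x, 2, 3, true), 3, 2, false);
  assert(roundTrip == x);
  return 0;
}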
@@ -19,12 +19,13 @@ limitations under the License. */
 namespace paddle {
 /**
- * A layer for rotating an input sample (assume it's a matrix)
- * The rotation is in clock-wise
+ * A layer for rotating a multi-channel feature map (M x N x C) in the spatial
+ * domain.
+ * The rotation is 90 degrees clock-wise:
  * \f[
-
  *   y(j,i) = x(M-i-1,j)
+ *   y(j,i,:) = x(M-i-1,j,:)
  * \f]
- * where \f$x\f$ is (M x N) input, and \f$y\f$ is (N x M) output.
+ * where \f$x\f$ is the (M x N x C) input, and \f$y\f$ is the (N x M x C) output.
  *
  * The config file api is rotate_layer
  *
@@ -41,9 +42,10 @@ public:
 private:
   int batchSize_;
-  int sampleSize_;
-  int sampleHeight_;
-  int sampleWidth_;
+  int size_;
+  int height_;
+  int width_;
+  int channels_;
 };
 
 }  // namespace paddle
@@ -1320,9 +1320,12 @@ TEST(Layer, RotateLayer) {
   TestConfig config;
   config.biasSize = 0;
   config.layerConfig.set_type("rotate");
-  const int INPUT_SIZE = 64;  // height * width
+  const int INPUT_SIZE = 64;  // height * width * depth
+  const int HEIGHT = 8;
+  const int WIDTH = 4;
   config.layerConfig.set_size(INPUT_SIZE);
-  config.layerConfig.set_height(32);
+  config.layerConfig.set_height(HEIGHT);
+  config.layerConfig.set_width(WIDTH);
   config.inputDefs.push_back({INPUT_DATA, "layer_0", INPUT_SIZE, 0});
   config.layerConfig.add_inputs();
...
@@ -388,6 +388,8 @@ void GpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
     matTrans = std::make_shared<GpuMatrix>(width_, height_);
   } else {
     CHECK(matTrans != NULL);
+    CHECK_EQ(matTrans->getHeight(), width_);
+    CHECK_EQ(matTrans->getWidth(), height_);
   }
   real* dataTrans = matTrans->getData();
   real* data = getData();
@@ -402,15 +404,13 @@ void GpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
     matRot = std::make_shared<GpuMatrix>(width_, height_);
   } else {
     CHECK(matRot != NULL);
+    CHECK_EQ(matRot->getHeight(), width_);
+    CHECK_EQ(matRot->getWidth(), height_);
   }
 
-  MatrixPtr cpuMat = std::make_shared<CpuMatrix>(height_, width_);
-  cpuMat->copyFrom(*this);
-
-  MatrixPtr cpuMatRot = std::make_shared<CpuMatrix>(width_, height_);
-  cpuMat->rotate(cpuMatRot, false, clockWise);
-  matRot->copyFrom(*cpuMatRot);
+  real* dataRot = matRot->getData();
+  real* data = getData();
+  hl_matrix_rotate(data, dataRot, height_, width_, clockWise);
 }
 
 MatrixPtr GpuMatrix::getInverse() {
@@ -1723,6 +1723,8 @@ void CpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
     matTrans = std::make_shared<CpuMatrix>(width_, height_);
   } else {
     CHECK(matTrans != NULL);
+    CHECK_EQ(matTrans->getHeight(), width_);
+    CHECK_EQ(matTrans->getWidth(), height_);
   }
   real* dataTrans = matTrans->getData();
   real* data = getData();
@@ -1741,18 +1743,18 @@ void CpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
     matRot = std::make_shared<CpuMatrix>(width_, height_);
   } else {
     CHECK(matRot != NULL);
+    CHECK_EQ(matRot->getHeight(), width_);
+    CHECK_EQ(matRot->getWidth(), height_);
   }
   real* dataRot = matRot->getData();
   real* data = getData();
-  int lda = getStride();
-  int ldc = matRot->getStride();
 
   for (size_t i = 0; i < height_; i++) {
     for (size_t j = 0; j < width_; j++) {
       if (clockWise) {
-        dataRot[j * ldc + i] = data[(height_ - i - 1) * lda + j];
+        dataRot[j * height_ + i] = data[(height_ - i - 1) * width_ + j];
       } else {
-        dataRot[j * ldc + i] = data[i * lda + (width_ - j - 1)];
+        dataRot[j * height_ + i] = data[i * width_ + (width_ - j - 1)];
       }
     }
   }
...
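The rewritten loop indexes both buffers with width_ and height_ directly where the old code used getStride(), so it silently assumes both matrices are stored contiguously (row stride equal to width). If that invariant were ever to need an explicit check, a guard along these lines could be added (hypothetical, not in this commit):

// Hypothetical guard (not in this commit): direct height_/width_ indexing
// is only valid when neither matrix has a row stride larger than its width.
CHECK_EQ(getStride(), width_);
CHECK_EQ(matRot->getStride(), matRot->getWidth());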
@@ -377,9 +377,19 @@ public:
   }
 
   /**
-   * @brief rotate clock-wise.
+   * @brief Rotate 90 degrees clock-wise if clockWise=true;
+   *        otherwise rotate 90 degrees anti-clock-wise.
+   * clock-wise:
+   * \f[
+   *   y(j,i) = x(M-i-1,j)
+   * \f]
+   * anti-clock-wise:
+   * \f[
+   *   y(j,i) = x(i,N-1-j)
+   * \f]
+   * where \f$x\f$ is the (M x N) input, and \f$y\f$ is the (N x M) output.
    *
-   * allocate matTrans' memory outside, then set memAlloc as false;
+   * allocate matRot's memory outside, then set memAlloc as false;
    * else set as true.
    */
   virtual void rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
...
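Note that the two mappings are mutual inverses, which is what RotateLayer::backward relies on: applying the clockwise map \f$y(j,i) = x(M-i-1,j)\f$ to an (M x N) input \f$x\f$ and then feeding the (N x M) result \f$y\f$ through the anti-clockwise map yields an (M x N) matrix \f$z\f$ with

\f[
  z(j,i) = y(i,\,M-1-j) = x(M-1-(M-1-j),\,i) = x(j,i).
\f]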
@@ -176,11 +176,29 @@ void testMatrixTranspose(int height, int width) {
   cpu->randomizeUniform();
   gpu->copyFrom(*cpu);
   cpu->transpose(cpuT, false);
-  gpu->transpose(gpuT, false);
+  gpu->transpose(gpuT, true);
   TensorCheckEqual(*cpuT, *gpuT);
 }
 
+void testMatrixRotate(int height, int width) {
+  MatrixPtr cpu = std::make_shared<CpuMatrix>(height, width);
+  MatrixPtr gpu = std::make_shared<GpuMatrix>(height, width);
+  MatrixPtr cpuR = std::make_shared<CpuMatrix>(width, height);
+  MatrixPtr gpuR = std::make_shared<GpuMatrix>(width, height);
+
+  cpu->randomizeUniform();
+  gpu->copyFrom(*cpu);
+
+  // clock-wise: CPU with pre-allocated result, GPU with memAlloc=true
+  cpu->rotate(cpuR, false, true);
+  gpu->rotate(gpuR, true, true);
+  TensorCheckEqual(*cpuR, *gpuR);
+
+  // anti-clock-wise: memAlloc flags swapped so both branches are covered
+  cpu->rotate(cpuR, true, false);
+  gpu->rotate(gpuR, false, false);
+  TensorCheckEqual(*cpuR, *gpuR);
+}
+
 void testMatrixInverse(int height) {
   MatrixPtr cpu = std::make_shared<CpuMatrix>(height, height);
   MatrixPtr gpu = std::make_shared<GpuMatrix>(height, height);
@@ -215,6 +233,7 @@ TEST(Matrix, unary) {
       testMatrixZeroAtOffset(height, width);
       testMatrixGetSum(height, width);
       testMatrixTranspose(height, width);
+      testMatrixRotate(height, width);
     }
     // inverse
     testMatrixInverse(height);
...