提交 43f6cdc8 编写于 作者: C chengduoZH

fix Matrix

上级 c792ef7d
...@@ -1389,7 +1389,7 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { ...@@ -1389,7 +1389,7 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
output_d, grad_d, mat_d, height_, width_); output_d, grad_d, mat_d, height_, width_);
} }
void GpuMatrix::vol2Col(real* data, void GpuMatrix::vol2Col(real* dataSrc,
int channels, int channels,
int depth, int depth,
int height, int height,
...@@ -1403,14 +1403,24 @@ void GpuMatrix::vol2Col(real* data, ...@@ -1403,14 +1403,24 @@ void GpuMatrix::vol2Col(real* data,
int paddingD, int paddingD,
int paddingH, int paddingH,
int paddingW) { int paddingW) {
hl_matrix_vol2Col(data, hl_matrix_vol2Col(dataSrc,
channels, depth, height, width, channels,
filterD, filterH, filterW, depth,
strideD, strideH, strideW, height,
paddingD, paddingH, paddingW, getData()); width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData());
} }
void GpuMatrix::col2Vol(real* trg, void GpuMatrix::col2Vol(real* dataDst,
int channels, int channels,
int depth, int depth,
int height, int height,
...@@ -1426,14 +1436,24 @@ void GpuMatrix::col2Vol(real* trg, ...@@ -1426,14 +1436,24 @@ void GpuMatrix::col2Vol(real* trg,
int paddingW, int paddingW,
real alpha, real alpha,
real beta) { real beta) {
hl_matrix_col2Vol(trg, hl_matrix_col2Vol(dataDst,
channels, depth, height, width, channels,
filterD, filterH, filterW, depth,
strideD, strideH, strideW, height,
paddingD, paddingH, paddingW, width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData(), getData(),
alpha, beta); alpha,
} beta);
}
/** /**
* CpuMatrix * CpuMatrix
...@@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg, ...@@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg,
real alpha, real alpha,
real beta) { real beta) {
real* src = getData(); real* src = getData();
int outDepth = (depth + 2 * paddingH - filterD) / strideD + 1; int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1;
int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; int outHeight = (height + 2 * paddingH - filterH) / strideH + 1;
int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; int outWidth = (width + 2 * paddingW - filterW) / strideW + 1;
int channelsCol = channels * filterD * filterH * filterW; int channelsCol = channels * filterD * filterH * filterW;
......
...@@ -1412,17 +1412,35 @@ public: ...@@ -1412,17 +1412,35 @@ public:
void vol2Col(real* data, void vol2Col(real* data,
int channels, int channels,
int depth, int height, int width, int depth,
int filterD, int filterH, int filterW, int height,
int strideD, int strideH, int strideW, int width,
int paddingD, int paddingH, int paddingW); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg, void col2Vol(real* trg,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real alpha, real beta); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
...@@ -1767,17 +1785,35 @@ public: ...@@ -1767,17 +1785,35 @@ public:
void vol2Col(real* data, void vol2Col(real* data,
int channels, int channels,
int depth, int height, int width, int depth,
int filterD, int filterH, int filterW, int height,
int strideD, int strideH, int strideW, int width,
int paddingD, int paddingH, int paddingW); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg, void col2Vol(real* trg,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real alpha, real beta); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
template <typename ExpressionType> template <typename ExpressionType>
void operator=(const ExpressionType& expr) { void operator=(const ExpressionType& expr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册