提交 43f6cdc8 编写于 作者: C chengduoZH

fix Matrix

上级 c792ef7d
...@@ -1389,51 +1389,71 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { ...@@ -1389,51 +1389,71 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
output_d, grad_d, mat_d, height_, width_); output_d, grad_d, mat_d, height_, width_);
} }
void GpuMatrix::vol2Col(real* data, void GpuMatrix::vol2Col(real* dataSrc,
int channels, int channels,
int depth, int depth,
int height, int height,
int width, int width,
int filterD, int filterD,
int filterH, int filterH,
int filterW, int filterW,
int strideD, int strideD,
int strideH, int strideH,
int strideW, int strideW,
int paddingD, int paddingD,
int paddingH, int paddingH,
int paddingW) { int paddingW) {
hl_matrix_vol2Col(data, hl_matrix_vol2Col(dataSrc,
channels, depth, height, width, channels,
filterD, filterH, filterW, depth,
strideD, strideH, strideW, height,
paddingD, paddingH, paddingW, getData()); width,
} filterD,
filterH,
void GpuMatrix::col2Vol(real* trg, filterW,
int channels, strideD,
int depth, strideH,
int height, strideW,
int width, paddingD,
int filterD, paddingH,
int filterH, paddingW,
int filterW, getData());
int strideD, }
int strideH,
int strideW, void GpuMatrix::col2Vol(real* dataDst,
int paddingD, int channels,
int paddingH, int depth,
int paddingW, int height,
real alpha, int width,
real beta) { int filterD,
hl_matrix_col2Vol(trg, int filterH,
channels, depth, height, width, int filterW,
filterD, filterH, filterW, int strideD,
strideD, strideH, strideW, int strideH,
paddingD, paddingH, paddingW, int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
hl_matrix_col2Vol(dataDst,
channels,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData(), getData(),
alpha, beta); alpha,
} beta);
}
/** /**
* CpuMatrix * CpuMatrix
...@@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg, ...@@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg,
real alpha, real alpha,
real beta) { real beta) {
real* src = getData(); real* src = getData();
int outDepth = (depth + 2 * paddingH - filterD) / strideD + 1; int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1;
int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; int outHeight = (height + 2 * paddingH - filterH) / strideH + 1;
int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; int outWidth = (width + 2 * paddingW - filterW) / strideW + 1;
int channelsCol = channels * filterD * filterH * filterW; int channelsCol = channels * filterD * filterH * filterW;
......
...@@ -1040,40 +1040,40 @@ public: ...@@ -1040,40 +1040,40 @@ public:
} }
virtual void vol2Col(real* data, virtual void vol2Col(real* data,
int channels, int channels,
int depth, int depth,
int height, int height,
int width, int width,
int filterD, int filterD,
int filterH, int filterH,
int filterW, int filterW,
int strideD, int strideD,
int strideH, int strideH,
int strideW, int strideW,
int paddingD, int paddingD,
int paddingH, int paddingH,
int paddingW) { int paddingW) {
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void col2Vol(real* trg, virtual void col2Vol(real* trg,
int channels, int channels,
int depth, int depth,
int height, int height,
int width, int width,
int filterD, int filterD,
int filterH, int filterH,
int filterW, int filterW,
int strideD, int strideD,
int strideH, int strideH,
int strideW, int strideW,
int paddingD, int paddingD,
int paddingH, int paddingH,
int paddingW, int paddingW,
real alpha, real alpha,
real beta) { real beta) {
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void bilinearForward(const Matrix& in, virtual void bilinearForward(const Matrix& in,
const size_t inImgH, const size_t inImgH,
...@@ -1411,18 +1411,36 @@ public: ...@@ -1411,18 +1411,36 @@ public:
const real ratioW); const real ratioW);
void vol2Col(real* data, void vol2Col(real* data,
int channels, int channels,
int depth, int height, int width, int depth,
int filterD, int filterH, int filterW, int height,
int strideD, int strideH, int strideW, int width,
int paddingD, int paddingH, int paddingW); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg, void col2Vol(real* trg,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real alpha, real beta); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
...@@ -1767,17 +1785,35 @@ public: ...@@ -1767,17 +1785,35 @@ public:
void vol2Col(real* data, void vol2Col(real* data,
int channels, int channels,
int depth, int height, int width, int depth,
int filterD, int filterH, int filterW, int height,
int strideD, int strideH, int strideW, int width,
int paddingD, int paddingH, int paddingW); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg, void col2Vol(real* trg,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real alpha, real beta); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
template <typename ExpressionType> template <typename ExpressionType>
void operator=(const ExpressionType& expr) { void operator=(const ExpressionType& expr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册