提交 43f6cdc8 编写于 作者: C chengduoZH

fix Matrix

上级 c792ef7d
......@@ -1389,51 +1389,71 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
output_d, grad_d, mat_d, height_, width_);
}
void GpuMatrix::vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
hl_matrix_vol2Col(data,
channels, depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW, getData());
}
void GpuMatrix::col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
hl_matrix_col2Vol(trg,
channels, depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW,
void GpuMatrix::vol2Col(real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
hl_matrix_vol2Col(dataSrc,
channels,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData());
}
void GpuMatrix::col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
hl_matrix_col2Vol(dataDst,
channels,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData(),
alpha, beta);
}
alpha,
beta);
}
/**
* CpuMatrix
......@@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg,
real alpha,
real beta) {
real* src = getData();
int outDepth = (depth + 2 * paddingH - filterD) / strideD + 1;
int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1;
int outHeight = (height + 2 * paddingH - filterH) / strideH + 1;
int outWidth = (width + 2 * paddingW - filterW) / strideW + 1;
int channelsCol = channels * filterD * filterH * filterW;
......
......@@ -1040,40 +1040,40 @@ public:
}
virtual void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
LOG(FATAL) << "Not implemeted";
}
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
LOG(FATAL) << "Not implemeted";
}
virtual void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
LOG(FATAL) << "Not implemeted";
}
virtual void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
LOG(FATAL) << "Not implemeted";
}
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
......@@ -1411,18 +1411,36 @@ public:
const real ratioW);
void vol2Col(real* data,
int channels,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW);
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real alpha, real beta);
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
......@@ -1767,17 +1785,35 @@ public:
void vol2Col(real* data,
int channels,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW);
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real alpha, real beta);
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册