提交 187e23a7 编写于 作者: C chengduoZH

fix MatMul parameter

上级 2a22da6c
...@@ -187,7 +187,8 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -187,7 +187,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(filter_slice, col_matrix, &out_slice); blas.MatMul(filter_slice, false, col_matrix, false, T(1.0), &out_slice,
T(0.0));
} }
} }
} }
...@@ -304,7 +305,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -304,7 +305,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice); col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix); blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
&col_matrix, T(0.0));
if (is_expand && data_dim == 2U) { if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides, col2im(dev_ctx, col, dilations, strides,
...@@ -351,8 +353,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -351,8 +353,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step); filter_grad_.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(out_grad_slice, false, col_matrix, true, blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
&filter_grad_slice); &filter_grad_slice, T(1.0));
} }
} }
} }
......
...@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch // col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
blas.MatMul(filter, true, input_batch, false, &col_matrix); blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
if (data_dim == 2U) { if (data_dim == 2U) {
// col2im: col_matrix -> dy // col2im: col_matrix -> dy
...@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or // or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w) // d, h, w)
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch); blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0));
} }
if (filter_grad) { if (filter_grad) {
// input batch // input batch
...@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or // or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w) // k_h * k_w)
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_); blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册