From 187e23a79cd4f60c8d59de50207857921f6423b8 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Tue, 8 May 2018 12:50:20 +0800
Subject: [PATCH] fix MatMul parameter

---
 paddle/fluid/operators/conv_op.h           | 10 ++++++----
 paddle/fluid/operators/conv_transpose_op.h |  9 ++++++---
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index c51898abb42..f462f00c080 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -187,7 +187,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
         // gemm
         Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
         Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-        blas.MatMul(filter_slice, col_matrix, &out_slice);
+        blas.MatMul(filter_slice, false, col_matrix, false, T(1.0), &out_slice,
+                    T(0.0));
       }
     }
   }
@@ -304,7 +305,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
             col_matrix.ShareDataWith(in_grad_slice);
             col_matrix.Resize(col_matrix_shape);
           }
-          blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
+          blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
+                      &col_matrix, T(0.0));
 
           if (is_expand && data_dim == 2U) {
             col2im(dev_ctx, col, dilations, strides,
@@ -351,8 +353,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           // gemm
           Tensor filter_grad_slice =
               filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-          blas.MatMul(out_grad_slice, false, col_matrix, true,
-                      &filter_grad_slice);
+          blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
+                      &filter_grad_slice, T(1.0));
         }
       }
     }
diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h
index 9276e5bfef7..898121412b1 100644
--- a/paddle/fluid/operators/conv_transpose_op.h
+++ b/paddle/fluid/operators/conv_transpose_op.h
@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
       // col_matrix = filter * input_batch
       // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-      blas.MatMul(filter, true, input_batch, false, &col_matrix);
+      blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
+                  &col_matrix, static_cast<T>(0.0));
 
       if (data_dim == 2U) {
         // col2im: col_matrix -> dy
@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
           // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
          // d, h, w)
-          blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
+          blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
+                      &input_grad_batch, static_cast<T>(0.0));
         }
         if (filter_grad) {
           // input batch
@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
           // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
           // k_h * k_w)
-          blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
+          blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
+                      &filter_grad_, static_cast<T>(1.0));
         }
       }
     }
--
GitLab
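
Note on the change above: every call is switched from the two-operand convenience
overload MatMul(A, B, &out) to the full GEMM-style overload
MatMul(A, trans_a, B, trans_b, alpha, &out, beta), which computes
out = alpha * op(A) * op(B) + beta * out, where op(X) is X or X^T. The forward
and input-gradient calls pass beta = 0 to overwrite the output, while the
filter-gradient calls pass beta = 1 so the gradient accumulates across batches
and groups instead of being clobbered. The following is a minimal plain-C++
sketch of those semantics only; the free function MatMulSketch and its row-major
flat-vector layout are assumptions for illustration, not the
paddle::operators::math::Blas API.

#include <cassert>
#include <iostream>
#include <vector>

// GEMM semantics matching the patched calls:
//   C = alpha * op(A) * op(B) + beta * C, with op(X) = X or X^T.
// Matrices are row-major; a_rows/a_cols describe A as stored, before op().
// (Illustrative sketch, not PaddlePaddle's Blas::MatMul implementation.)
void MatMulSketch(const std::vector<float>& a, int a_rows, int a_cols,
                  bool trans_a, const std::vector<float>& b, int b_rows,
                  int b_cols, bool trans_b, float alpha, std::vector<float>* c,
                  float beta) {
  const int m = trans_a ? a_cols : a_rows;   // rows of op(A)
  const int k = trans_a ? a_rows : a_cols;   // cols of op(A)
  const int n = trans_b ? b_rows : b_cols;   // cols of op(B)
  assert(k == (trans_b ? b_cols : b_rows));  // inner dimensions must agree
  assert(static_cast<int>(c->size()) == m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        const float av = trans_a ? a[p * a_cols + i] : a[i * a_cols + p];
        const float bv = trans_b ? b[j * b_cols + p] : b[p * b_cols + j];
        acc += av * bv;
      }
      // beta = 0 overwrites C (forward / input-gradient calls above);
      // beta = 1 accumulates into C (filter-gradient calls above).
      (*c)[i * n + j] = alpha * acc + beta * (*c)[i * n + j];
    }
  }
}

int main() {
  std::vector<float> a = {1, 2, 3, 4};  // 2x2
  std::vector<float> b = {5, 6, 7, 8};  // 2x2
  std::vector<float> c = {1, 1, 1, 1};  // 2x2, pre-filled so beta matters
  // A*B = [[19, 22], [43, 50]]; with beta = 1 the old contents of c are
  // kept and added to, mirroring the filter-gradient accumulation.
  MatMulSketch(a, 2, 2, false, b, 2, 2, false, 1.0f, &c, 1.0f);
  for (float v : c) std::cout << v << ' ';  // prints: 20 23 44 51
  std::cout << '\n';
  return 0;
}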