提交 e8cd4b7d 编写于 作者: Z zchen0211

deconv2d impl in full

上级 43aad989
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/conv2d_op.h"
#include "paddle/operators/deconv2d_op.h" #include "paddle/operators/deconv2d_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -158,9 +158,6 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -158,9 +158,6 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
int O_W = output_grad->dims()[3]; int O_W = output_grad->dims()[3];
// Two functors required to get to the right shape // Two functors required to get to the right shape
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
paddle::operators::math::Im2ColFunctor< paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T> paddle::operators::math::ColFormat::kCFO, Place, T>
im2col; im2col;
...@@ -231,7 +228,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -231,7 +228,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
strides[0], strides[1], paddings[0], paddings[1]); strides[0], strides[1], paddings[0], paddings[1]);
// gemm: d_filter = x * y_grad^T // gemm: d_filter = x * y_grad^T
math::matmul<Place, T>(context.device_context(), in_batch, false, math::matmul<Place, T>(context.device_context(), in_batch, false,
col_matrix, true, T(1.0), &filter_grad, T(1.0)); col_matrix, true, T(1.0), &filter_grad_, T(1.0));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册