Commit 82aa5693 authored by chengduoZH

follow comments

Parent 1431f251
@@ -63,29 +63,25 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     const Tensor* input = context.Input<Tensor>("Input");
     // The filter will be reshaped, so it should not be constant pointer
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* output = context.Output<Tensor>("Output");

     std::vector<int> strides = context.Attr<std::vector<int>>("strides");

     // TODO(Zhuoyuan): Paddings can be added in future.
     // groups will alway be disabled in conv2dtranspose.

-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t h = input->dims()[2];
+    const int64_t w = input->dims()[3];

-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
+    const int64_t k_h = filter.dims()[2];
+    const int64_t k_w = filter.dims()[3];

-    const int c = output->dims()[1];  // output channels
-    const int o_h = output->dims()[2];
-    const int o_w = output->dims()[3];
+    const int64_t c = output->dims()[1];  // output channels
+    const int64_t o_h = output->dims()[2];
+    const int64_t o_w = output->dims()[3];

-    paddle::operators::math::Col2ImFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        col2im;
+    math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;

     // use col_shape in the im2col and col2im calculation
     DDim col_shape = {c, k_h, k_w, h, w};
@@ -105,19 +101,18 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     DDim output_shape = {c, o_h, o_w};
     DDim input_matrix_shape = {m, h * w};

+    // filter size: (m, c * k_h * k_w)
     DDim filter_matrix_shape = {m, c * k_h * k_w};
     filter.Resize(filter_matrix_shape);

-    // convolution transpose: gemm + col2im (similar to conv-backward on input)
     output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(context.device_context(), output, static_cast<T>(0));

+    // convolution transpose: gemm + col2im (similar to conv-backward on input)
     for (int i = 0; i < batch_size; i++) {
-      // batch with size (M, h * w)
+      // batch with size (m, h * w)
       Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, c * k_h * k_w)

       // output size: (c, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
@@ -125,7 +120,11 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
       // col_matrix = filter * input_batch
       // of shape (c * k_h * k_w, h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+                             input_batch, false, static_cast<T>(1.0),
+                             &col_matrix, static_cast<T>(0.0));

+      // col2im: col_matrix -> dy
+      // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
       col2im(context.device_context(), output_batch, col, strides[0],
              strides[1], 0, 0, 0, 0);
     }
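For reference, one iteration of the loop above is a single gemm followed by a col2im scatter-add. The sketch below reproduces that scheme for one sample with plain loops over std::vector data instead of math::matmul and math::Col2ImFunctor; the tensor shapes, the (c, k_h, k_w, h, w) layout of col, and the zero-padding output sizes o_h = (h - 1) * s_h + k_h and o_w = (w - 1) * s_w + k_w are taken from the kernel above, while the function name and the flat row-major indexing are only illustrative assumptions.

#include <vector>

// Transposed 2D convolution for one sample, done as gemm + col2im.
// input: (m, h, w), filter: (m, c, k_h, k_w), returns output: (c, o_h, o_w).
std::vector<float> Conv2DTransposeNaive(const std::vector<float>& input,
                                        const std::vector<float>& filter,
                                        int m, int h, int w, int c, int k_h,
                                        int k_w, int s_h, int s_w) {
  const int o_h = (h - 1) * s_h + k_h;
  const int o_w = (w - 1) * s_w + k_w;

  // gemm: col(c * k_h * k_w, h * w) = filter^T(c * k_h * k_w, m) * input(m, h * w)
  std::vector<float> col(static_cast<size_t>(c) * k_h * k_w * h * w, 0.f);
  for (int ck = 0; ck < c * k_h * k_w; ++ck) {
    for (int hw = 0; hw < h * w; ++hw) {
      float sum = 0.f;
      for (int im = 0; im < m; ++im) {
        sum += filter[im * c * k_h * k_w + ck] * input[im * h * w + hw];
      }
      col[ck * h * w + hw] = sum;
    }
  }

  // col2im: scatter-add col(c, k_h, k_w, h, w) into output(c, o_h, o_w);
  // strides are applied here and paddings are zero, as in the kernel above.
  std::vector<float> output(static_cast<size_t>(c) * o_h * o_w, 0.f);
  for (int ic = 0; ic < c; ++ic) {
    for (int kh = 0; kh < k_h; ++kh) {
      for (int kw = 0; kw < k_w; ++kw) {
        for (int ih = 0; ih < h; ++ih) {
          for (int iw = 0; iw < w; ++iw) {
            const int oh = ih * s_h + kh;
            const int ow = iw * s_w + kw;
            output[(ic * o_h + oh) * o_w + ow] +=
                col[(((ic * k_h + kh) * k_w + kw) * h + ih) * w + iw];
          }
        }
      }
    }
  }
  return output;
}

Given the same per-sample input, filter, and strides, this should match one output_batch produced by the kernel, up to floating-point rounding.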
@@ -143,7 +142,6 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // For filter, we do not use const pointer b/c we will do reshape,
     // but we should avoid modifying its value.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad =
@@ -153,35 +151,24 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t h = input->dims()[2];
+    const int64_t w = input->dims()[3];

-    const int k_h = filter.dims()[2];
-    const int k_w = filter.dims()[3];
+    const int64_t k_h = filter.dims()[2];
+    const int64_t k_w = filter.dims()[3];

-    const int c = output_grad->dims()[1];  // output channels
-    const int o_h = output_grad->dims()[2];
-    const int o_w = output_grad->dims()[3];
+    const int64_t c = output_grad->dims()[1];  // output channels
+    const int64_t o_h = output_grad->dims()[2];
+    const int64_t o_w = output_grad->dims()[3];

     // Only im2col functor required for bp to get to the right shape
-    paddle::operators::math::Im2ColFunctor<
-        paddle::operators::math::ColFormat::kCFO, Place, T>
-        im2col;
+    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;

     // use col_shape in the im2col and col2im calculation
     DDim col_shape = {c, k_h, k_w, h, w};

-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-
     DDim output_shape = {c, o_h, o_w};
     DDim input_matrix_shape = {m, h * w};
@@ -191,67 +178,60 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     // convolution transpose grad on input:
     // im2col + gemm (similar to conv-forward)
     // input need to compute gradient
-    if (input_grad) {
+    if (input_grad || filter_grad) {
+      Tensor col;
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // col_matrix shares the same piece of data with col,
+      // but will be reshaped into a two-dimensional matrix shape
+      // to call the matrix multiplication interface.
       Tensor col_matrix;
       col_matrix.ShareDataWith(col);
       DDim col_matrix_shape = {c * k_h * k_w, h * w};
       col_matrix.Resize(col_matrix_shape);

-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // filter of size (m, c * k_h * k_w)
-
-        // batch with size (m, h, w)
-        Tensor input_grad_batch =
-            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-
-        // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
-        im2col(context.device_context(), output_grad_batch, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
-
-        // gemm: dx = filter * dy
-        // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), filter, false,
-                               col_matrix, false, T(1.0), &input_grad_batch,
-                               T(0.0));
-      }
-    }
-
-    // filter gradient required
-    if (filter_grad) {
-      Tensor col_matrix_f;
-      col_matrix_f.ShareDataWith(col);
-      DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
-      col_matrix_f.Resize(col_matrix_shape_f);
-
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; ++i) {
-        // batch with size (c, o_h, o_w)
+      Tensor filter_grad_;
+      math::SetConstant<Place, T> set_zero;
+
+      if (input_grad) {
+        input_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), input_grad, static_cast<T>(0));
+      }
+      if (filter_grad) {  // filter size (m, c, k_h, k_w)
+        filter_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+        filter_grad_ = *filter_grad;
+        filter_grad_.Resize(filter_matrix_shape);
+      }
+
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_h * o_w)
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
-        // input batch
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

-        // im2col: (c * h * w, k_h * k_w)
+        // im2col: dy -> col matrix
+        // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
         im2col(context.device_context(), output_grad_batch, col, strides[0],
                strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);

-        // gemm: d_filter = x * y_grad^T
-        // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
-        math::matmul<Place, T>(context.device_context(), in_batch, false,
-                               col_matrix_f, true, T(1.0), &filter_grad_,
-                               T(1.0));
+        if (input_grad) {
+          // batch with size (m, h, w)
+          Tensor input_grad_batch =
+              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: dx = filter * dy
+          // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
+          math::matmul<Place, T>(context.device_context(), filter, false,
+                                 col_matrix, false, static_cast<T>(1.0),
+                                 &input_grad_batch, static_cast<T>(0.0));
+        }
+        if (filter_grad) {
+          // input batch
+          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: d_filter = x * dy^T
+          // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
+          math::matmul<Place, T>(context.device_context(), in_batch, false,
                                 col_matrix, true, static_cast<T>(1.0),
+                                 &filter_grad_, static_cast<T>(1.0));
+        }
       }
     }
   }
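For reference, with col_matrix = im2col(dy_i) of shape (c * k_h * k_w, h * w), the two gemms inside the restructured loop above work out, per sample i, to

    dx_i = filter * col_matrix   : (m, c * k_h * k_w) x (c * k_h * k_w, h * w) -> (m, h * w)
    dw  += x_i * col_matrix^T    : (m, h * w) x (h * w, c * k_h * k_w) -> (m, c * k_h * k_w)

The filter-gradient gemm passes beta = 1.0, so it accumulates across the batch, which is why filter_grad is zeroed once with set_zero before the loop; the input-gradient gemm passes beta = 0.0 and overwrites each input_grad_batch directly.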
@@ -267,30 +247,28 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");

     std::vector<int> strides = context.Attr<std::vector<int>>("strides");

     // TODO(chengduo): Paddings can be added in future.
     // groups will alway be disabled in conv3dtranspose.

-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int d = input->dims()[2];
-    const int h = input->dims()[3];
-    const int w = input->dims()[4];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t d = input->dims()[2];
+    const int64_t h = input->dims()[3];
+    const int64_t w = input->dims()[4];

-    const int k_d = filter.dims()[2];
-    const int k_h = filter.dims()[3];
-    const int k_w = filter.dims()[4];
+    const int64_t k_d = filter.dims()[2];
+    const int64_t k_h = filter.dims()[3];
+    const int64_t k_w = filter.dims()[4];

-    const int c = output->dims()[1];  // output channels
-    const int o_d = output->dims()[2];
-    const int o_h = output->dims()[3];
-    const int o_w = output->dims()[4];
+    const int64_t c = output->dims()[1];  // output channels
+    const int64_t o_d = output->dims()[2];
+    const int64_t o_h = output->dims()[3];
+    const int64_t o_w = output->dims()[4];

     paddle::operators::math::Col2VolFunctor<Place, T> col2vol;

     // use col_shape in the vol2col and col2vol calculation
     DDim col_shape = {c, k_d, k_h, k_w, d, h, w};

     // use col_matrix_shape in the gemm calculation
     DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
@@ -306,19 +284,18 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
     DDim output_shape = {c, o_d, o_h, o_w};
     DDim input_matrix_shape = {m, d * h * w};

+    // filter size: (m, c * k_d * k_h * k_w)
     DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
     filter.Resize(filter_matrix_shape);

-    // convolution transpose: gemm + col2vol (similar to conv-backward on input)
     output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+    math::SetConstant<Place, T> set_zero;
+    set_zero(context.device_context(), output, static_cast<T>(0));

+    // convolution transpose: gemm + col2vol (similar to conv-backward on input)
     for (int i = 0; i < batch_size; i++) {
-      // batch with size (M, d * h * w)
+      // batch with size (m, d * h * w)
       Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // filter size: (M, c * k_d * k_h * k_w)

       // output size: (c, o_d, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
@@ -326,7 +303,10 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
       // col_matrix = filter * input_batch
       // of shape (c * k_d * k_h * k_w, d * h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+                             input_batch, false, static_cast<T>(1.0),
+                             &col_matrix, static_cast<T>(0.0));
+      // col2vol: col_matrix -> dy
+      // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
       col2vol(context.device_context(), output_batch, col, strides[0],
               strides[1], strides[2], 0, 0, 0);
     }
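Since col2vol is called with zero paddings here, the output volume sizes follow the usual transposed-convolution relation (stated for context only; the kernel reads o_d, o_h, o_w from the Output tensor rather than recomputing them):

    o_d = (d - 1) * strides[0] + k_d
    o_h = (h - 1) * strides[1] + k_h
    o_w = (w - 1) * strides[2] + k_w

and, assuming the same patch convention as the 2D kCFO functors, col2vol scatter-adds col(c, k_d, k_h, k_w, d, h, w) into output(c, o_d, o_h, o_w) at offsets (id * strides[0] + kd, ih * strides[1] + kh, iw * strides[2] + kw).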
@@ -344,7 +324,6 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // For filter, we do not use const pointer b/c we will do reshape,
     // but we should avoid modifying its value.
     Tensor filter = *context.Input<Tensor>("Filter");
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
     Tensor* filter_grad =
@@ -354,20 +333,20 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

-    const int batch_size = input->dims()[0];
-    const int m = input->dims()[1];
-    const int d = input->dims()[2];
-    const int h = input->dims()[3];
-    const int w = input->dims()[4];
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const int64_t m = input->dims()[1];
+    const int64_t d = input->dims()[2];
+    const int64_t h = input->dims()[3];
+    const int64_t w = input->dims()[4];

-    const int k_d = filter.dims()[2];
-    const int k_h = filter.dims()[3];
-    const int k_w = filter.dims()[4];
+    const int64_t k_d = filter.dims()[2];
+    const int64_t k_h = filter.dims()[3];
+    const int64_t k_w = filter.dims()[4];

-    const int c = output_grad->dims()[1];  // output channels
-    const int o_d = output_grad->dims()[2];
-    const int o_h = output_grad->dims()[3];
-    const int o_w = output_grad->dims()[4];
+    const int64_t c = output_grad->dims()[1];  // output channels
+    const int64_t o_d = output_grad->dims()[2];
+    const int64_t o_h = output_grad->dims()[3];
+    const int64_t o_w = output_grad->dims()[4];

     // Only vol2col functor required for bp to get to the right shape
     paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
@@ -378,12 +357,6 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // use col_matrix_shape in the gemm calculation
     DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w};

-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-
     DDim output_shape = {c, o_d, o_h, o_w};
     DDim input_matrix_shape = {m, d * h * w};
@@ -393,70 +366,62 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     // convolution transpose grad on input:
     // vol2col + gemm (similar to conv-forward)
     // input need to compute gradient
-    if (input_grad) {
+    if (input_grad || filter_grad) {
+      Tensor col;
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      // col_matrix shares the same piece of data with col,
+      // but will be reshaped into a two-dimensional matrix shape
+      // to call the matrix multiplication interface.
       Tensor col_matrix;
       col_matrix.ShareDataWith(col);
       DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
       col_matrix.Resize(col_matrix_shape);

-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_d * o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // filter of size (m, c * k_d * k_h * k_w)
-
-        // batch with size (m, d, h, w)
-        Tensor input_grad_batch =
-            input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-
-        // vol2col: dy from (c, o_d, o_h, o_w) -> (c * k_d * k_h * k_w, d * h *
-        // w)
-        vol2col(context.device_context(), output_grad_batch, col, strides[0],
-                strides[1], strides[2], paddings[0], paddings[1], paddings[2]);
-
-        // gemm: dx = filter * dy
-        // (m, c *k_d * k_h * k_w) * (c * k_d * k_h * k_w, d* h * w) -> (m, c,
-        // d, h, w)
-        math::matmul<Place, T>(context.device_context(), filter, false,
-                               col_matrix, false, T(1.0), &input_grad_batch,
-                               T(0.0));
-      }
-    }
-
-    // filter gradient required
-    if (filter_grad) {
-      Tensor col_matrix_f;
-      col_matrix_f.ShareDataWith(col);
-      DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w};
-      col_matrix_f.Resize(col_matrix_shape_f);
-
-      filter_grad->mutable_data<T>(context.GetPlace());
-      Tensor filter_grad_ = *filter_grad;
-      filter_grad_.Resize(filter_matrix_shape);
-      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
-      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
-
-      for (int i = 0; i < batch_size; ++i) {
-        // batch with size (c, o_d, o_h, o_w)
+      Tensor filter_grad_;
+      math::SetConstant<Place, T> set_zero;
+
+      if (input_grad) {
+        input_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), input_grad, static_cast<T>(0));
+      }
+      if (filter_grad) {  // filter size (m, c * k_d * k_h * k_w)
+        filter_grad->mutable_data<T>(context.GetPlace());
+        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
+        filter_grad_ = *filter_grad;
+        filter_grad_.Resize(filter_matrix_shape);
+      }
+
+      for (int i = 0; i < batch_size; i++) {
+        // batch with size (c, o_d * o_h * o_w)
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
-        // input batch
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

-        // vol2col: (c * d * h * w, k_d * k_h * k_w)
+        // vol2col: dy -> col_matrix
+        // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
         vol2col(context.device_context(), output_grad_batch, col, strides[0],
                 strides[1], strides[2], paddings[0], paddings[1], paddings[2]);

-        // gemm: d_filter = x * y_grad^T
-        // (m, c * d * h * w) * (k_d * k_h * k_w, c * d * h * w) -> (m, c, d, h,
-        // w)
-        math::matmul<Place, T>(context.device_context(), in_batch, false,
-                               col_matrix_f, true, T(1.0), &filter_grad_,
-                               T(1.0));
+        if (input_grad) {
+          // batch with size (m, d, h, w)
+          Tensor input_grad_batch =
+              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: dx = filter * dy
+          // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
+          // d, h, w)
+          math::matmul<Place, T>(context.device_context(), filter, false,
                                 col_matrix, false, static_cast<T>(1.0),
+                                 &input_grad_batch, static_cast<T>(0.0));
+        }
+        if (filter_grad) {
+          // input batch
+          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
+          // gemm: d_filter = x * dy^T
+          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
+          // k_h * k_w)
+          math::matmul<Place, T>(context.device_context(), in_batch, false,
                                 col_matrix, true, static_cast<T>(1.0),
+                                 &filter_grad_, static_cast<T>(1.0));
+        }
       }
     }
   }