Unverified commit e800c0d3, authored by chengduo and committed by GitHub

Merge pull request #5791 from chengduoZH/fix_conv_op

remove vector::erase
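For context, the change replaces the pattern of erasing the leading two entries (batch/channel or k_o/k_i) from each shape vector with a data_dim offset, so the vectors keep the full tensor layout and spatial dimensions are read at index j + 2 (e.g. IsExpand now checks filter_dim[j + 2]). Below is a minimal standalone sketch of the before/after pattern, using hypothetical shape values; it is not the kernel code itself.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical 4-D filter shape {k_o, k_i, k_h, k_w}.
  std::vector<int64_t> filter_shape = {64, 3, 3, 3};

  // Old pattern: erase the first two dims so only the spatial dims remain.
  std::vector<int64_t> spatial_old(filter_shape);
  spatial_old.erase(spatial_old.begin(), spatial_old.begin() + 2);

  // New pattern: keep the full vector and index the spatial dims with j + 2.
  std::size_t data_dim = filter_shape.size() - 2;
  std::vector<int64_t> spatial_new(data_dim);
  for (std::size_t j = 0; j < data_dim; ++j) {
    spatial_new[j] = filter_shape[j + 2];
  }

  // Both approaches yield the same spatial dims {3, 3}.
  for (std::size_t j = 0; j < data_dim; ++j) {
    std::cout << spatial_old[j] << " == " << spatial_new[j] << "\n";
  }
  return 0;
}

Keeping the full vector avoids the element shifting that vector::erase performs and lets one vector serve both the full-shape uses and the spatial-only uses.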
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
@@ -206,26 +202,22 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-        if (is_expand && filter_shape_vec.size() == 2) {
+        if (is_expand && data_dim == 2U) {
           col2im(context.device_context(), col, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &in_grad_slice);
-        } else if (is_expand && filter_shape_vec.size() == 3) {
+        } else if (is_expand && data_dim == 3U) {
           col2vol(context.device_context(), col, dilations, strides, paddings,
                   &in_grad_slice);
         }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
         }
......
@@ -68,30 +68,26 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -136,7 +132,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                               input_batch, false, static_cast<T>(1.0),
                               &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 2) {
+      if (data_dim == 2U) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
@@ -144,7 +140,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (data_dim == 3U) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
        col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -176,30 +172,26 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output_grad->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -248,7 +240,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
       Tensor output_grad_batch =
           output_grad->Slice(i, i + 1).Resize(output_shape);
 
-      if (filter_shape_vec.size() == 2) {
+      if (data_dim == 2U) {
         // im2col: dy -> col matrix
         // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
         im2col(context.device_context(), output_grad_batch,
@@ -256,7 +248,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &col);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (data_dim == 3U) {
         // vol2col: dy -> col_matrix
         // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
         vol2col(context.device_context(), output_grad_batch, dilations,
......