未验证 提交 e19032fb 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #6743 from chengduoZH/profiling/02.recognize_digits

Refine elementwiseAdd and im2col
...@@ -103,11 +103,13 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> { ...@@ -103,11 +103,13 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() { MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++j_; ++j_;
i_ = j_ / post_; if (UNLIKELY(j_ == post_)) {
if (UNLIKELY(i_ == n_)) { ++i_;
j_ = 0; j_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0; i_ = 0;
} }
}
return *this; return *this;
} }
...@@ -125,10 +127,10 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> { ...@@ -125,10 +127,10 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
private: private:
const T* ptr_; const T* ptr_;
int i_; int64_t i_;
int64_t j_; int64_t j_;
int64_t n_; int64_t n_;
int post_; int64_t post_;
}; };
#ifdef __NVCC__ #ifdef __NVCC__
......
...@@ -61,14 +61,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -61,14 +61,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col->data<T>(); T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w; int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
...@@ -130,16 +129,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -130,16 +129,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height && if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) { (im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * im_height; im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
im_data[im_row_idx * im_width + im_col_idx] +=
col_data[(c * col_height + h) * col_width + w]; col_data[(c * col_height + h) * col_width + w];
} }
} }
...@@ -199,12 +196,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -199,12 +196,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset = int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0]; col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_col_offset = int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1]; col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset = int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels + ((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) * channel) *
...@@ -271,12 +269,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -271,12 +269,13 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset = int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0]; col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_col_offset = int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1]; col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset = int col_offset =
(((col_row_idx * col_width + col_col_idx) * im_channels + (((col_row_idx * col_width + col_col_idx) * im_channels +
channel) * channel) *
...@@ -284,6 +283,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -284,6 +283,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < im_height && if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < im_width) { im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset = int im_offset =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册