提交 2b4edacc,作者:luotao1

Enhance the forward pass of the concat op

上级 557be6fc
@@ -48,16 +48,16 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
     auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
     // computation
-    for (int k = 0; k < out_rows; ++k) {
-      T* dst_ptr = output->data<T>() + k * out_cols;
-      int col_idx = 0;
-      for (int j = 0; j < num; ++j) {
-        int col_len = input_cols[j];
-        const T* src_prt = input[j].data<T>() + k * col_len;
-        memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
-                     sizeof(T) * col_len);
-        col_idx += col_len;
-      }
+    auto output_data = output->data<T>();
+    int col_idx = 0;
+    for (int j = 0; j < num; ++j) {
+      int col_len = input_cols[j];
+      auto input_data = input[j].data<T>();
+      for (int k = 0; k < out_rows; ++k) {
+        memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
+                     input_data + k * col_len, sizeof(T) * col_len);
+      }
+      col_idx += col_len;
     }
   }
 };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册