提交 2b4edacc 编写于 作者: L luotao1

enhance the forward of concat op

上级 557be6fc
......@@ -48,16 +48,16 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < out_rows; ++k) {
T* dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T* src_prt = input[j].data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt,
sizeof(T) * col_len);
col_idx += col_len;
auto output_data = output->data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int k = 0; k < out_rows; ++k) {
memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
input_data + k * col_len, sizeof(T) * col_len);
}
col_idx += col_len;
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册