提交 762160bd 编写于 作者: Q qiaolongfei

fix concat grad kernel

上级 2074d369
...@@ -209,7 +209,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -209,7 +209,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
outputs_cols[0] = 0; outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) { for (int i = 0; i < o_num; ++i) {
int t_col = outputs->at(i)->numel() / out_row; int t_col = ref_inputs.at(i)->numel() / out_row;
if (sameShape) { if (sameShape) {
if (t_col != out0_col) sameShape = false; if (t_col != out0_col) sameShape = false;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册