未验证 提交 2a992178 编写于 作者: H haosicheng 提交者: GitHub

fix reduce mean grad bug *test=kunlun (#45401)

* add temporal shift and grad *test=kunlun

* fix reduce mean grad bug *test=kunlun
上级 91298884
...@@ -112,6 +112,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> { ...@@ -112,6 +112,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
d = d + xdims.size(); d = d + xdims.size();
} }
reduce_numel *= xdims[d]; reduce_numel *= xdims[d];
ydims.insert(ydims.begin() + d, 1);
} }
float val = 1.0f / static_cast<float>(reduce_numel); float val = 1.0f / static_cast<float>(reduce_numel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册