未验证 提交 a7c4facb 编写于 作者: H haosicheng 提交者: GitHub

fix reduce mean grad bug *test=kunlun (#45511)

fix missing keep_dim variable

    fix missing grad check in unittest

    add new test case
上级 db235bf0
...@@ -90,6 +90,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> { ...@@ -90,6 +90,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
bool reduce_all = ctx.Attr<bool>("reduce_all"); bool reduce_all = ctx.Attr<bool>("reduce_all");
auto reduce_dims = ctx.Attr<std::vector<int>>("dim"); auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");
std::vector<int> xdims; std::vector<int> xdims;
for (int i = 0; i < input->dims().size(); i++) { for (int i = 0; i < input->dims().size(); i++) {
...@@ -112,7 +113,13 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> { ...@@ -112,7 +113,13 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
d = d + xdims.size(); d = d + xdims.size();
} }
reduce_numel *= xdims[d]; reduce_numel *= xdims[d];
ydims.insert(ydims.begin() + d, 1); }
if (keep_dim != true) {
sort(reduce_dims.begin(), reduce_dims.end());
for (auto& d : reduce_dims) {
ydims.insert(ydims.begin() + d, 1);
}
} }
float val = 1.0f / static_cast<float>(reduce_numel); float val = 1.0f / static_cast<float>(reduce_numel);
......
...@@ -103,6 +103,9 @@ class XPUTestReduce(XPUOpTestWrapper): ...@@ -103,6 +103,9 @@ class XPUTestReduce(XPUOpTestWrapper):
# def test_check_grad(self): # def test_check_grad(self):
# self.check_output_with_place(self.place, ['X'], 'Out') # self.check_output_with_place(self.place, ['X'], 'Out')
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class Test2DReduce0(Test1DReduce): class Test2DReduce0(Test1DReduce):
def setUp(self): def setUp(self):
...@@ -161,6 +164,18 @@ class XPUTestReduce(XPUOpTestWrapper): ...@@ -161,6 +164,18 @@ class XPUTestReduce(XPUOpTestWrapper):
'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
} }
class Test6DReduce(Test1DReduce):
def setUp(self):
super().setUp()
self.attrs = {'dim': [1, -1], 'use_xpu': True}
self.inputs = {
'X': np.random.random((5, 6, 7, 8, 9, 10)).astype(self.dtype)
}
self.outputs = {
'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
}
class TestKeepDimReduce(Test1DReduce): class TestKeepDimReduce(Test1DReduce):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册