未验证 提交 3409db95 编写于 作者: W wangchaochaohu 提交者: GitHub

fix reduce bug test=develop (#19971)

上级 3ea2b661
...@@ -197,6 +197,9 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -197,6 +197,9 @@ class ReduceOp : public framework::OperatorWithKernel {
remove(dims_vector.begin(), dims_vector.end(), kDelFlag), remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end()); dims_vector.end());
} }
if (!keep_dim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = framework::make_ddim(dims_vector); auto out_dims = framework::make_ddim(dims_vector);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (dims[0] != 0) { if (dims[0] != 0) {
......
...@@ -397,5 +397,19 @@ class TestReduceAll(OpTest): ...@@ -397,5 +397,19 @@ class TestReduceAll(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class Test1DReduceWithAxes1(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random(1).astype("float64")}
self.attrs = {'dim': [0], 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册