Unverified · Commit 2c1d494e authored by Xinyu Chen, committed by GitHub

elementwise: onednn: support zero dimension inputs (#51656)

Parent 04025237
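A 0-dim ("zero dimension") tensor is a scalar whose shape is the empty tuple, which is distinct from a 1-element vector. The oneDNN elementwise kernels previously assumed inputs of rank 1 or higher, so this commit promotes an empty shape to `{1}` wherever a oneDNN memory descriptor is built from tensor dims. A minimal NumPy illustration of the input shapes involved (not from the commit itself):

```python
import numpy as np

# Illustration only: the "zero dimension inputs" in the commit title are
# 0-dim (scalar) arrays, whose shape is the empty tuple ().
scalar = np.array(3.0)          # ndim == 0, shape == ()
vector = np.random.random(100)  # ndim == 1, shape == (100,)

# Broadcasting a 0-dim operand against a vector yields the vector's shape;
# the oneDNN kernels patched below now accept this case as well.
print(np.add(vector, scalar).shape)  # (100,)
```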
@@ -91,8 +91,9 @@ class TransferLayoutFunctor {
       phi::funcs::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
       phi::OneDNNContext::tls().set_cur_paddle_data_layout(in_layout);
     }
-    auto out_tz = phi::vectorize<int64_t>(out_tensor.dims());
+    auto out_tz = out_tensor.dims().size() == 0
+                      ? std::vector<int64_t>{1}
+                      : phi::vectorize(out_tensor.dims());
     dnnl::memory::data_type in_type =
         phi::funcs::ToOneDNNDataType(in_tensor.dtype());
...
@@ -936,8 +936,13 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
     // if output tensor(z) is nullptr then we are computing into oneDNN
     // managed buffer
     auto rankdiff = x->dims().size() - y->dims().size();
-    auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
-                                   : vectorize(out->dims());
+    auto dst_tz =
+        (out == nullptr)
+            ? (rankdiff > 0 ? src_x_tz
+                            : (y->dims().size() == 0 ? std::vector<int64_t>{1}
+                                                     : src_x_tz))
+            : (out->dims().size() == 0 ? std::vector<int64_t>{1}
+                                       : vectorize(out->dims()));
     auto src0_md = x->mem_desc();
     auto src1_md = y->mem_desc();
@@ -1074,7 +1079,8 @@ class BroadcastDataOneDNNHandler
                              float scale_y,
                              const std::vector<int64_t>& extended_x_dims)
       : OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
-    const auto src0_tz = vectorize(out->dims());
+    const auto src0_tz = out->dims().size() == 0 ? std::vector<int64_t>{1}
+                                                 : vectorize(out->dims());
     const auto src0_md = dnnl::memory::desc(
         src0_tz, OneDNNGetDataType<T>(), GetPlainOneDNNFormat(src0_tz.size()));
     const auto src1_md = x->mem_desc().reshape(extended_x_dims);
...
@@ -97,8 +97,10 @@ inline void BroadcastReduction(const Place& place,
       {DNNL_ARG_DST, *dst_memory},
   });
   astream.wait();
-  grad_tensor->set_mem_desc(dst_memory->get_desc().reshape(
-      phi::vectorize<int64_t>(grad_tensor->dims())));
+  auto grad_shape = grad_tensor->dims().size() == 0
+                        ? std::vector<int64_t>{1}
+                        : phi::vectorize<int64_t>(grad_tensor->dims());
+  grad_tensor->set_mem_desc(dst_memory->get_desc().reshape(grad_shape));
 }
 }  // namespace funcs
...
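All four C++ hunks above apply the same guard. As a sketch, the shared rule reduces to the following (the `promote_zero_dim` helper is hypothetical and shown in Python for brevity; it is not part of the commit):

```python
def promote_zero_dim(shape):
    """Hypothetical helper mirroring the repeated C++ change above:
    oneDNN cannot describe a rank-0 memory, so an empty shape is
    promoted to [1]; all other shapes pass through unchanged."""
    return [1] if len(shape) == 0 else list(shape)


assert promote_zero_dim(()) == [1]        # 0-dim scalar tensor
assert promote_zero_dim((100,)) == [100]  # rank >= 1 is untouched
```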
@@ -103,6 +103,27 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestMKLDNNElementwiseAddOp):
     pass
 
+
+class TestMKLDNNElementwiseAddOpZeroDim(TestMKLDNNElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.random((100,)).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.add(self.x, self.y)
+
+
+class TestMKLDNNElementwiseAddOpZeroDim2(TestMKLDNNElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.random.random((100,)).astype(self.dtype)
+        self.out = np.add(self.x, self.y)
+
+
+class TestMKLDNNElementwiseAddOpZeroDim3(TestMKLDNNElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.add(self.x, self.y)
+
+
 ''' INT8 Tests '''
...
@@ -111,6 +111,45 @@ class TestMKLDNNElementwiseDivOp5(TestMKLDNNElementwiseDivOp):
     pass
 
+
+class TestMKLDNNElementwiseDivOpZeroDim(TestMKLDNNElementwiseDivOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.divide(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+
+class TestMKLDNNElementwiseDivOpZeroDim2(TestMKLDNNElementwiseDivOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
+        self.out = np.divide(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+
+class TestMKLDNNElementwiseDivOpZeroDim3(TestMKLDNNElementwiseDivOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.divide(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+
 @OpTestTool.skip_if_not_cpu_bf16()
 class TestBf16(TestMKLDNNElementwiseDivOp):
     def setUp(self):
...
@@ -76,6 +76,54 @@ class TestMKLDNNElementwiseMulOp5(TestMKLDNNElementwiseMulOp):
     pass
 
+
+class TestMKLDNNElementwiseMulOpZeroDim(TestMKLDNNElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.random.random((100,)).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.multiply(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+
+class TestMKLDNNElementwiseMulOpZeroDim2(TestMKLDNNElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.random.random((100,)).astype(self.dtype)
+        self.out = np.multiply(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+
+class TestMKLDNNElementwiseMulOpZeroDim3(TestMKLDNNElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.multiply(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+
 ''' INT8 Tests '''
...
@@ -133,6 +133,54 @@ class TestElementwiseSubOp_xsize_lessthan_ysize_sub(TestMKLDNNElementwiseSubOp):
         self.axis = 2
 
+
+class TestMKLDNNElementwiseSubOpZeroDim(TestMKLDNNElementwiseSubOp):
+    def init_input_output(self):
+        self.x = np.random.random((100,)).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.subtract(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+    def test_check_grad_ignore_y(self):
+        pass
+
+
+class TestMKLDNNElementwiseSubOpZeroDim2(TestMKLDNNElementwiseSubOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.random.random((100,)).astype(self.dtype)
+        self.out = np.subtract(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+    def test_check_grad_ignore_y(self):
+        pass
+
+
+class TestMKLDNNElementwiseSubOpZeroDim3(TestMKLDNNElementwiseSubOp):
+    def init_input_output(self):
+        self.x = np.array(3.0).astype(self.dtype)
+        self.y = np.array(3.0).astype(self.dtype)
+        self.out = np.subtract(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+    def test_check_grad_ignore_y(self):
+        pass
+
+
 @OpTestTool.skip_if_not_cpu_bf16()
 class TestBf16(TestMKLDNNElementwiseSubOp):
     def setUp(self):
...
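The new tests pair a 0-dim operand with a vector, and two 0-dim operands with each other, for add, div, mul, and sub. At the Python API level, the behavior they lock in looks roughly like the sketch below (not from the commit; assumes a Paddle build with oneDNN enabled and a version where `paddle.to_tensor(3.0)` creates a true 0-dim tensor):

```python
import paddle

# Sketch of the user-visible behavior covered by the tests above.
paddle.device.set_device("cpu")
x = paddle.rand([100])
y = paddle.to_tensor(3.0)      # 0-dim tensor, shape []
print(paddle.add(x, y).shape)  # [100] - scalar broadcasts over the vector
print(paddle.add(y, y).shape)  # []    - two 0-dim inputs yield a 0-dim result
```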