未验证 提交 2c1d494e 编写于 作者: X Xinyu Chen 提交者: GitHub

elementwise: onednn: support zero dimension inputs (#51656)

上级 04025237
......@@ -91,8 +91,9 @@ class TransferLayoutFunctor {
phi::funcs::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
phi::OneDNNContext::tls().set_cur_paddle_data_layout(in_layout);
}
auto out_tz = phi::vectorize<int64_t>(out_tensor.dims());
auto out_tz = out_tensor.dims().size() == 0
? std::vector<int64_t>{1}
: phi::vectorize(out_tensor.dims());
dnnl::memory::data_type in_type =
phi::funcs::ToOneDNNDataType(in_tensor.dtype());
......
......@@ -936,8 +936,13 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: vectorize(out->dims());
auto dst_tz =
(out == nullptr)
? (rankdiff > 0 ? src_x_tz
: (y->dims().size() == 0 ? std::vector<int64_t>{1}
: src_x_tz))
: (out->dims().size() == 0 ? std::vector<int64_t>{1}
: vectorize(out->dims()));
auto src0_md = x->mem_desc();
auto src1_md = y->mem_desc();
......@@ -1074,7 +1079,8 @@ class BroadcastDataOneDNNHandler
float scale_y,
const std::vector<int64_t>& extended_x_dims)
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_tz = vectorize(out->dims());
const auto src0_tz = out->dims().size() == 0 ? std::vector<int64_t>{1}
: vectorize(out->dims());
const auto src0_md = dnnl::memory::desc(
src0_tz, OneDNNGetDataType<T>(), GetPlainOneDNNFormat(src0_tz.size()));
const auto src1_md = x->mem_desc().reshape(extended_x_dims);
......
......@@ -97,8 +97,10 @@ inline void BroadcastReduction(const Place& place,
{DNNL_ARG_DST, *dst_memory},
});
astream.wait();
grad_tensor->set_mem_desc(dst_memory->get_desc().reshape(
phi::vectorize<int64_t>(grad_tensor->dims())));
auto grad_shape = grad_tensor->dims().size() == 0
? std::vector<int64_t>{1}
: phi::vectorize<int64_t>(grad_tensor->dims());
grad_tensor->set_mem_desc(dst_memory->get_desc().reshape(grad_shape));
}
} // namespace funcs
......
......@@ -103,6 +103,27 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestMKLDNNElementwiseAddOp):
pass
class TestMKLDNNElementwiseAddOpZeroDim(TestMKLDNNElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((100,)).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestMKLDNNElementwiseAddOpZeroDim2(TestMKLDNNElementwiseAddOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.random.random((100,)).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestMKLDNNElementwiseAddOpZeroDim3(TestMKLDNNElementwiseAddOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.add(self.x, self.y)
''' INT8 Tests '''
......
......@@ -111,6 +111,45 @@ class TestMKLDNNElementwiseDivOp5(TestMKLDNNElementwiseDivOp):
pass
class TestMKLDNNElementwiseDivOpZeroDim(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.divide(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
class TestMKLDNNElementwiseDivOpZeroDim2(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
class TestMKLDNNElementwiseDivOpZeroDim3(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.divide(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
@OpTestTool.skip_if_not_cpu_bf16()
class TestBf16(TestMKLDNNElementwiseDivOp):
def setUp(self):
......
......@@ -76,6 +76,54 @@ class TestMKLDNNElementwiseMulOp5(TestMKLDNNElementwiseMulOp):
pass
class TestMKLDNNElementwiseMulOpZeroDim(TestMKLDNNElementwiseMulOp):
def init_input_output(self):
self.x = np.random.random((100,)).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_y(self):
pass
def test_check_grad_ingore_x(self):
pass
class TestMKLDNNElementwiseMulOpZeroDim2(TestMKLDNNElementwiseMulOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.random.random((100,)).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_y(self):
pass
def test_check_grad_ingore_x(self):
pass
class TestMKLDNNElementwiseMulOpZeroDim3(TestMKLDNNElementwiseMulOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_y(self):
pass
def test_check_grad_ingore_x(self):
pass
''' INT8 Tests '''
......
......@@ -133,6 +133,54 @@ class TestElementwiseSubOp_xsize_lessthan_ysize_sub(TestMKLDNNElementwiseSubOp):
self.axis = 2
class TestMKLDNNElementwiseSubOpZeroDim(TestMKLDNNElementwiseSubOp):
def init_input_output(self):
self.x = np.random.random((100,)).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.subtract(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
def test_check_grad_ignore_y(self):
pass
class TestMKLDNNElementwiseSubOpZeroDim2(TestMKLDNNElementwiseSubOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.random.random((100,)).astype(self.dtype)
self.out = np.subtract(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
def test_check_grad_ignore_y(self):
pass
class TestMKLDNNElementwiseSubOpZeroDim3(TestMKLDNNElementwiseSubOp):
def init_input_output(self):
self.x = np.array(3.0).astype(self.dtype)
self.y = np.array(3.0).astype(self.dtype)
self.out = np.subtract(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
def test_check_grad_ignore_y(self):
pass
@OpTestTool.skip_if_not_cpu_bf16()
class TestBf16(TestMKLDNNElementwiseSubOp):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册