未验证 提交 a5bf156b 编写于 作者: Y ykkk2333 提交者: GitHub

[XPU] add einsum fill diagonal and diagonal kernels (#49465)

* migrate shaple sgd, split,sign xpu kernels to phi, test=kunlun

* fix dlrm throughput problem, test=kunlun

* add xpu einsum, fill_diagonal, and diagonal kernels, test=kunlun
上级 ee49994f
......@@ -133,10 +133,19 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"diagonal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"dropout_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"dropout",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"einsum", XPUKernelSet({phi::DataType::FLOAT32})},
{"einsum_raw", XPUKernelSet({phi::DataType::FLOAT32})},
{"einsum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_add",
......@@ -210,6 +219,12 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fill_diagonal_tensor",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fill_constant",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DiagonalKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
T* out_data = dev_ctx.template Alloc<T>(out);
std::vector<int64_t> xshape = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> yshape = phi::vectorize<int64_t>(out->dims());
int r = xpu::diagonal(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out_data),
xshape,
yshape,
axis1,
axis2,
offset);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "diagonal");
}
} // namespace phi
PD_REGISTER_KERNEL(diagonal,
XPU,
ALL_LAYOUT,
phi::DiagonalKernel,
float,
phi::dtype::float16,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/einsum_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"
PD_REGISTER_KERNEL(einsum_grad, XPU, ALL_LAYOUT, phi::EinsumGradKernel, float) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/einsum_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum_raw, XPU, ALL_LAYOUT, phi::EinsumKernelRaw, float) {}
PD_REGISTER_KERNEL(einsum, XPU, ALL_LAYOUT, phi::EinsumKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalTensorKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
int64_t offset,
int dim1,
int dim2,
DenseTensor *out) {
using XPUType = typename XPUTypeTrait<T>::Type;
T *out_data = ctx.template Alloc<T>(out);
int r = xpu::copy(ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
reinterpret_cast<XPUType *>(out_data),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
std::vector<int64_t> xshape = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> yshape = phi::vectorize<int64_t>(y.dims());
r = xpu::fill_diagonal_tensor(ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
reinterpret_cast<const XPUType *>(y.data<T>()),
reinterpret_cast<XPUType *>(out_data),
xshape,
yshape,
dim1,
dim2,
offset);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fill_diagonal_tensor");
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_tensor,
XPU,
ALL_LAYOUT,
phi::FillDiagonalTensorKernel,
float,
int64_t,
int,
phi::dtype::float16,
bool) {}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
sys.path.append("..")
from op_test import skip_check_grad_ci
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
class XPUTestDiagonalOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'diagonal'
self.use_dynamic_create_class = False
@skip_check_grad_ci(
reason="xpu fill_diagonal_tensor is not implemented yet"
)
class TestDiagonalOp(XPUOpTest):
def setUp(self):
self.op_type = "diagonal"
self.python_api = paddle.diagonal
self.dtype = self.in_type
self.init_config()
self.outputs = {'Out': self.target}
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def init_config(self):
self.case = np.random.randn(10, 5, 2).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
class TestDiagonalOpCase1(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(4, 2, 4, 4).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 3, 'axis2': 0}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
class TestDiagonalOpCase2(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(100, 100).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
class TestDiagonalOpCase3(TestDiagonalOp):
def init_config(self):
self.case = np.random.randint(0, 2, (4, 2, 4, 4)).astype('bool')
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 3, 'axis2': 0}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
def test_check_grad(self):
pass
class TestDiagonalOpCase4(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(100, 100).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': 1, 'axis1': 1, 'axis2': 0}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
def test_check_grad(self):
pass
class TestDiagonalOpCase5(TestDiagonalOp):
def init_config(self):
self.case = np.random.randn(4, 2, 4, 4).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
self.target = np.diagonal(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
class TestDiagonalAPI(unittest.TestCase):
def setUp(self):
self.shape = [10, 3, 4]
self.x = np.random.random((10, 3, 4)).astype(np.float32)
self.place = paddle.XPUPlace(0)
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.shape)
out = paddle.diagonal(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
out_ref = np.diagonal(self.x)
for out in res:
np.testing.assert_allclose(out, out_ref, rtol=1e-08)
def test_api_dygraph(self):
paddle.disable_static(self.place)
x_tensor = paddle.to_tensor(self.x)
out = paddle.diagonal(x_tensor)
out_ref = np.diagonal(self.x)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-08)
paddle.enable_static()
def test_api_eager(self):
paddle.disable_static(self.place)
x_tensor = paddle.to_tensor(self.x)
out = paddle.diagonal(x_tensor)
out2 = paddle.diagonal(x_tensor, offset=0, axis1=2, axis2=1)
out3 = paddle.diagonal(x_tensor, offset=1, axis1=0, axis2=1)
out4 = paddle.diagonal(x_tensor, offset=0, axis1=1, axis2=2)
out_ref = np.diagonal(self.x)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-08)
out2_ref = np.diagonal(self.x, offset=0, axis1=2, axis2=1)
np.testing.assert_allclose(out2.numpy(), out2_ref, rtol=1e-08)
out3_ref = np.diagonal(self.x, offset=1, axis1=0, axis2=1)
np.testing.assert_allclose(out3.numpy(), out3_ref, rtol=1e-08)
out4_ref = np.diagonal(self.x, offset=0, axis1=1, axis2=2)
np.testing.assert_allclose(out4.numpy(), out4_ref, rtol=1e-08)
paddle.enable_static()
support_types = get_xpu_op_support_types('diagonal')
for stype in support_types:
create_test_class(globals(), XPUTestDiagonalOp, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
class XPUTestEinsumOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'einsum'
self.use_dynamic_create_class = False
class TestEinsumBinary(XPUOpTest):
def setUp(self):
self.op_type = "einsum"
self.disable = False
self.types = [self.in_type, self.in_type]
self.set_mandatory()
self.init_input()
np.random.seed(123)
out = np.einsum(self.equation, *self.inputs)
self.operands = []
for idx, inp in enumerate(self.inputs):
self.operands.append(("x" + str(idx), inp))
self.inputs = {"Operands": self.operands}
self.attrs = {"equation": self.equation}
self.outputs = {
'Out': out,
"InnerCache": [
('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))
],
"XShape": [
('xshape_' + str(i), np.array([1.0]))
for i in range(len(self.operands))
],
}
def init_input(self):
self.inputs = []
for t, s in zip(self.types, self.shapes):
self.inputs.append(np.random.random(s).astype(t))
def set_mandatory(self):
self.shapes = [(10, 10, 20), (20, 6)]
self.equation = "mij,jk->ki"
def test_check_output(self):
if not self.disable:
self.check_output_with_place(
paddle.XPUPlace(0),
no_check_set=["InnerCache", "XShape"],
atol=5e-3,
)
def test_grad(self):
if not self.disable:
self.check_grad_with_place(
paddle.XPUPlace(0), [op[0] for op in self.operands], ["Out"]
)
class TestEinsum1(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(20, 3, 3), (20, 3, 3)]
self.equation = "mij,mjk->mik"
class TestEinsum2(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(20, 3, 3), (20, 3, 3)]
self.equation = "mij,mjk->ikm"
class TestEinsum3(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 10), (10, 10)]
self.equation = "ij,jk->ik" # }}}
class TestEinsumWithReduction(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 3, 5), (5, 30)]
self.equation = "ijk,kl->jl"
class TestEinsumWithReduction1(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 3, 3, 5), (10, 5, 10, 10)]
self.equation = "mijk,mklh->ljm"
class TestEinsumWithUnary(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 10, 3, 5)]
self.equation = "mijk->mi"
class TestEinsumWithUnary1(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 10, 3, 3), (3, 6, 3, 10)]
self.equation = "imjl,jklm->imk"
class TestEinsumWithBroadcast1(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 10, 3, 3)]
self.equation = "i...->..."
class TestEinsumWithBroadcast2(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 11), (3, 4, 5, 10)]
self.equation = "...ij,...i->j..."
class TestEinsumWithBroadcast4(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
self.equation = "a...d,...cb->...abcd"
class TestEinsumWithBroadcast5(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(3, 2, 2, 10), (10, 3, 2, 2)]
self.equation = "...a,a...->..."
class TestEinsumWithBroadcast6(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(100), (100)]
self.equation = "i,i->"
class TestEinsumWithDiagonal(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 10)]
self.equation = "ii->"
class TestEinsumWithDiagonal2(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 3, 10)]
self.equation = "iji->j"
class TestEinsumWithDiagonal3(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 3, 2, 1, 4, 5)]
self.equation = "a...a->..."
class TestEinsumWithDiagonal4(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 3, 2, 1, 4, 5)]
self.equation = "a...a->a..."
class TestEinsumWithDiagonal5(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(8, 8, 8)]
self.equation = "aaa->a"
class TestEinsumWithDiagonal6(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)]
self.equation = "ijki,jkjk->ik"
class TestEinsumWithDiagonal8(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)]
self.equation = "ijki,jkjk->"
support_types = get_xpu_op_support_types('einsum')
for stype in support_types:
create_test_class(globals(), XPUTestEinsumOp, stype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
sys.path.append("..")
from op_test import skip_check_grad_ci
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
def fill_diagonal_ndarray(x, value, offset=0, dim1=0, dim2=1):
"""Fill value into the diagonal of x that offset is ${offset} and the coordinate system is (dim1, dim2)."""
strides = x.strides
shape = x.shape
if dim1 > dim2:
dim1, dim2 = dim2, dim1
assert 0 <= dim1 < dim2 <= 2
assert len(x.shape) == 3
dim_sum = dim1 + dim2
dim3 = len(x.shape) - dim_sum
if offset >= 0:
diagdim = min(shape[dim1], shape[dim2] - offset)
diagonal = np.lib.stride_tricks.as_strided(
x[:, offset:] if dim_sum == 1 else x[:, :, offset:],
shape=(shape[dim3], diagdim),
strides=(strides[dim3], strides[dim1] + strides[dim2]),
)
else:
diagdim = min(shape[dim2], shape[dim1] + offset)
diagonal = np.lib.stride_tricks.as_strided(
x[-offset:, :] if dim_sum in [1, 2] else x[:, -offset:],
shape=(shape[dim3], diagdim),
strides=(strides[dim3], strides[dim1] + strides[dim2]),
)
diagonal[...] = value
return x
def fill_gt(x, y, offset, dim1, dim2):
if dim1 > dim2:
dim1, dim2 = dim2, dim1
offset = -offset
xshape = x.shape
yshape = y.shape
if len(xshape) != 3:
perm_list = []
unperm_list = [0] * len(xshape)
idx = 0
for i in range(len(xshape)):
if i != dim1 and i != dim2:
perm_list.append(i)
unperm_list[i] = idx
idx += 1
perm_list += [dim1, dim2]
unperm_list[dim1] = idx
unperm_list[dim2] = idx + 1
x = np.transpose(x, perm_list)
y = y.reshape(-1, yshape[-1])
nxshape = x.shape
x = x.reshape((-1, xshape[dim1], xshape[dim2]))
out = fill_diagonal_ndarray(x, y, offset, 1, 2)
if len(xshape) != 3:
out = out.reshape(nxshape)
out = np.transpose(out, unperm_list)
return out
class XPUTestFillDiagTensorOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'fill_diagonal_tensor'
self.use_dynamic_create_class = False
@skip_check_grad_ci(
reason="xpu fill_diagonal_tensor is not implemented yet"
)
class TensorFillDiagTensor_Test(XPUOpTest):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((10, 10)).astype(self.dtype)
y = np.random.random((10,)).astype(self.dtype)
dim1 = 0
dim2 = 1
offset = 0
out = fill_gt(x, y, offset, dim1, dim2)
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
def init_kernel_type(self):
self.dtype = self.in_type
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((2, 20, 25)).astype(self.dtype)
y = np.random.random((2, 20)).astype(self.dtype)
dim1 = 2
dim2 = 1
offset = -3
out = fill_gt(x, y, offset, dim1, dim2)
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test):
def setUp(self):
self.op_type = "fill_diagonal_tensor"
self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor
self.init_kernel_type()
x = np.random.random((2, 20, 20, 3)).astype(self.dtype)
y = np.random.random((2, 3, 18)).astype(self.dtype)
dim1 = 1
dim2 = 2
offset = 2
out = fill_gt(x, y, offset, dim1, dim2)
self.inputs = {"X": x, "Y": y}
self.outputs = {'Out': out}
self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2}
support_types = get_xpu_op_support_types('fill_diagonal_tensor')
for stype in support_types:
create_test_class(globals(), XPUTestFillDiagTensorOp, stype)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册