Unverified commit a64a722a, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] ReshapeTransform/nll_loss/matmul support 0D (#53828)

Parent 5745a63f
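In user-facing terms, the reduced/scalar outputs of these APIs become 0-D tensors instead of shape-[1] tensors. A minimal dynamic-graph sketch of the intended behavior, mirroring the new tests further down (the printed shapes assume a Paddle build that includes this commit):

```python
import paddle

# vector * vector: the matmul result is now a 0-D tensor rather than shape [1]
x = paddle.randn([10])
y = paddle.randn([10])
z = paddle.matmul(x, y)
print(z.shape)  # [] (previously [1])

# reduced nll_loss: the loss is now a 0-D tensor as well
log_out = paddle.nn.LogSoftmax(axis=1)(paddle.rand([5, 3]))
label = paddle.randint(0, 3, [5], "int64")
loss = paddle.nn.functional.nll_loss(log_out, label)  # default reduction='mean'
print(loss.shape)  # []
```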
......@@ -666,10 +666,6 @@ class MatMulOp : public framework::OperatorWithKernel {
dim_out.resize(dim_out.size() - 1);
}
if (dim_out.empty()) {
dim_out = {1};
}
phi::DDim ddim_out = phi::make_ddim(dim_out);
context->SetOutputDim("Out", ddim_out);
......
......@@ -91,9 +91,6 @@ void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const {
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
ctx->ShareLoD("X", "Out");
......
......@@ -334,8 +334,9 @@ void ExecuteMatMulV1(const ExecutionContext &ctx,
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
: std::vector<int64_t>{1};
out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
template <typename T>
......
......@@ -146,8 +146,9 @@ inline void ExecuteMul(const OneDNNContext& dev_ctx,
// This kernel flattens dims, so afterwards we need to set the unflattened
// shape on out. The reshape requires a plain layout, but
// MatmulV2MKLDNNHanlder enforces one, so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
: std::vector<int64_t>{1};
out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
template <typename T, typename T_out>
......@@ -177,8 +178,9 @@ inline void ExecuteMatmul(const OneDNNContext& dev_ctx,
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
: std::vector<int64_t>{1};
out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
} // namespace funcs
......
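All three oneDNN hunks above apply the same guard: when the output is 0-D, the memory descriptor is reshaped to {1} instead of an empty dim list (my reading is that the descriptor needs at least one dimension, while the Paddle tensor itself stays 0-D). A tiny Python sketch of that fallback, with a hypothetical helper name of my own:

```python
# Illustrative sketch of the reshape_dims fallback used in the oneDNN paths
# above; mem_desc_reshape_dims is not part of the codebase.
def mem_desc_reshape_dims(out_dims):
    # keep the real dims, but describe a 0-D output as a single element
    return list(out_dims) if len(out_dims) != 0 else [1]

assert mem_desc_reshape_dims([2, 3]) == [2, 3]
assert mem_desc_reshape_dims([]) == [1]  # 0-D output
```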
......@@ -822,11 +822,11 @@ void NllLossGradInferMeta(const MetaTensor& x,
if (check) {
auto batch_size = x_dims[0];
if (x_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dout_dims.size(),
1,
phi::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 1"));
if (reduction == "none") {
PADDLE_ENFORCE_EQ(dout_dims.size(),
1,
phi::errors::InvalidArgument(
"The dimensions of Input(Out@Grad) must be 1"));
PADDLE_ENFORCE_EQ(
dout_dims[0],
batch_size,
......@@ -834,10 +834,10 @@ void NllLossGradInferMeta(const MetaTensor& x,
"The unreduced size ofInput(Out@Grad) must be the "
"same as batch_size."));
} else {
PADDLE_ENFORCE_EQ(dout_dims[0],
1,
PADDLE_ENFORCE_EQ(dout_dims.size(),
0,
phi::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
"The dimensions of Input(Out@Grad) must be 0"));
}
} else if (x_dims.size() == 4) {
if (reduction == "none") {
......@@ -855,10 +855,10 @@ void NllLossGradInferMeta(const MetaTensor& x,
"The dimensions of Input(Out@Grad) must be match "
"to Input(Label) dimensions."));
} else {
PADDLE_ENFORCE_EQ(dout_dims[0],
1,
PADDLE_ENFORCE_EQ(dout_dims.size(),
0,
phi::errors::InvalidArgument(
"The reduced size of Input(Out@Grad) must be 1"));
"The dimensions of Input(Out@Grad) must be 0"));
}
}
}
......
......@@ -2057,9 +2057,6 @@ void MatmulInferMeta(const MetaTensor& x,
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto ddim_out = phi::make_ddim(new_dims);
......
......@@ -831,7 +831,7 @@ void NllLossRawInferMeta(const MetaTensor& input,
if (reduction == "none") {
out->set_dims({x_dims[0]});
} else {
out->set_dims({1});
out->set_dims(phi::make_ddim({}));
}
} else if (x_dims.size() == 4) {
PADDLE_ENFORCE_EQ(label_dims.size(),
......@@ -854,10 +854,10 @@ void NllLossRawInferMeta(const MetaTensor& input,
if (reduction == "none") {
out->set_dims({x_dims[0], x_dims[2], x_dims[3]});
} else {
out->set_dims({1});
out->set_dims(phi::make_ddim({}));
}
}
total_weight->set_dims({1});
total_weight->set_dims(phi::make_ddim({}));
out->set_dtype(input.dtype());
total_weight->set_dtype(input.dtype());
}
......
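Combined with the nll_loss kernel change further down, the inferred shapes become (a sketch following NllLossRawInferMeta above; reduction='mean'/'sum' now yields 0-D):

```python
import paddle
import paddle.nn.functional as F

log_out = paddle.nn.LogSoftmax(axis=1)(paddle.rand([5, 3, 2, 4]))
label = paddle.randint(0, 3, [5, 2, 4], "int64")

print(F.nll_loss(log_out, label, reduction='none').shape)  # [5, 2, 4]
print(F.nll_loss(log_out, label, reduction='mean').shape)  # [] after this change
print(F.nll_loss(log_out, label, reduction='sum').shape)   # [] after this change
```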
......@@ -126,7 +126,7 @@ void MatMulFunctionImplWithBlas(
M,
N));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->Resize(phi::make_ddim({}));
dev_ctx.template Alloc<T>(Out);
blas.GEMM(CblasNoTrans,
CblasTrans,
......@@ -516,7 +516,7 @@ void MatMulFunctionImplWithCublasLt(
N));
// MatMul's case 0 => vector * vector
Out->Resize({1});
Out->Resize(phi::make_ddim({}));
dev_ctx.template Alloc<T>(Out);
VLOG(3) << "MatMul with blaslt case 1";
blaslt::Run(dev_ctx,
......
......@@ -24,8 +24,7 @@ void NllLossKernel(const Context& dev_ctx,
const std::string& reduction,
DenseTensor* out) {
DenseTensor total_weight;
total_weight.set_meta(
DenseTensorMeta(phi::CppTypeToDataType<T>::Type(), {1}));
total_weight.set_meta(DenseTensorMeta(phi::CppTypeToDataType<T>::Type(), {}));
dev_ctx.template Alloc<T>(total_weight);
NllLossRawKernel(dev_ctx,
input,
......
......@@ -856,8 +856,8 @@ class ReshapeTransform(Transform):
# [[[1., 1., 1.],
# [1., 1., 1.]]])
print(reshape_transform.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.])
# Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.)
"""
_type = Type.BIJECTION
......@@ -945,8 +945,7 @@ class ReshapeTransform(Transform):
)
def _forward_log_det_jacobian(self, x):
# TODO(zhouwei): should not set shape to [1], which is []
shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1]
shape = x.shape[: x.dim() - len(self._in_event_shape)]
return paddle.zeros(shape, dtype=x.dtype)
......
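The shape of forward_log_det_jacobian now follows the batch shape exactly, so a 0-D input gives a 0-D result. A sketch matching the updated docstring and unit test (I assume ReshapeTransform is reachable via paddle.distribution and that the test constructs it with an empty in-event shape):

```python
import paddle
from paddle.distribution import ReshapeTransform  # assumed import path

x = paddle.full([], 1.0)                 # 0-D input, as in the updated test
reshape = ReshapeTransform([], [1, 1])   # empty in-event shape -> [1, 1] out-event shape
out = reshape.forward(x)

print(out.shape)                                  # [1, 1]
print(reshape.inverse(out).shape)                 # []
print(reshape.forward_log_det_jacobian(x).shape)  # [] after this change
```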
......@@ -77,12 +77,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
......@@ -177,9 +171,6 @@ class API_TestMm(unittest.TestCase):
with fluid.program_guard(fluid.Program()):
x = paddle.static.data(name="x", shape=[2], dtype="float64")
y = paddle.static.data(name='y', shape=[2], dtype='float64')
res = paddle.static.data(
name="output", shape=[1], dtype="float64"
)
result = paddle.mm(x, y)
exe = fluid.Executor(fluid.CPUPlace())
data1 = np.random.rand(2)
......@@ -187,9 +178,7 @@ class API_TestMm(unittest.TestCase):
np_res = exe.run(
feed={'x': data1, 'y': data2}, fetch_list=[result]
)
expected_result = np.matmul(
data1.reshape(1, 2), data2.reshape(2, 1)
)
expected_result = np.matmul(data1, data2)
np.testing.assert_allclose(
np_res,
......
......@@ -102,12 +102,6 @@ def reference_matmul_mul_head(
Y = transpose_mat(Y)
Out = matmul_head(X, Y, head_number)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
......@@ -196,12 +190,6 @@ def reference_matmul_mul_head2(
Y = transpose_mat(Y)
Out = matmul_head2(X, Y, head_number)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
......
......@@ -45,12 +45,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
return Out
......
......@@ -2571,6 +2571,33 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out2.shape, [])
self.assertEqual(out2, 2.5)
def test_matmul(self):
# 1) no transpose
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out1 = paddle.matmul(x, y)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(x.grad.shape, [10])
self.assertEqual(y.grad.shape, [10])
# 2) transpose x and y
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out2 = paddle.matmul(x, y, True, True)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(x.grad.shape, [10])
self.assertEqual(y.grad.shape, [10])
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
......@@ -4872,6 +4899,40 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[1].shape, ())
self.assertEqual(res[1], 2.5)
@prog_scope()
def test_matmul(self):
# 1) no transpose
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out = paddle.matmul(x, y)
paddle.static.append_backward(out)
self.assertEqual(out.shape, ())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (10,))
self.assertEqual(res[2].shape, (10,))
# 2) transpose x and y
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out = paddle.matmul(x, y, True, True)
paddle.static.append_backward(out)
self.assertEqual(out.shape, ())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (10,))
self.assertEqual(res[2].shape, (10,))
@prog_scope()
def test_linalg_slogdet(self):
# 2-D input
......@@ -5994,6 +6055,31 @@ class TestLossAPI(unittest.TestCase):
self.assertEqual(loss.shape, [])
self.assertEqual(input.grad.shape, [3, 5])
def test_nll_loss(self):
input = paddle.rand([5, 3])
input.stop_gradient = False
log_softmax = paddle.nn.LogSoftmax(axis=1)
log_out = log_softmax(input)
label = paddle.randint(0, 3, [5], "int64")
loss = paddle.nn.functional.nll_loss(log_out, label)
loss.backward()
self.assertEqual(loss.shape, [])
self.assertEqual(input.grad.shape, [5, 3])
input = paddle.rand([5, 3, 2, 4])
input.stop_gradient = False
log_softmax = paddle.nn.LogSoftmax(axis=1)
log_out = log_softmax(input)
label = paddle.randint(0, 3, [5, 2, 4], "int64")
loss = paddle.nn.functional.nll_loss(log_out, label)
loss.backward()
self.assertEqual(loss.shape, [])
self.assertEqual(input.grad.shape, [5, 3, 2, 4])
class TestLossAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -6060,6 +6146,40 @@ class TestLossAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (3, 5))
@prog_scope()
def test_nll_loss(self):
input = paddle.rand([5, 3])
input.stop_gradient = False
log_softmax = paddle.nn.LogSoftmax(axis=1)
log_out = log_softmax(input)
label = paddle.randint(0, 3, shape=[5])
label.stop_gradient = False
loss = paddle.nn.functional.nll_loss(log_out, label)
paddle.static.append_backward(loss)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (5, 3))
input = paddle.rand([5, 3, 2, 4])
input.stop_gradient = False
log_softmax = paddle.nn.LogSoftmax(axis=1)
log_out = log_softmax(input)
label = paddle.randint(0, 3, shape=[5, 2, 4])
label.stop_gradient = False
loss = paddle.nn.functional.nll_loss(log_out, label)
paddle.static.append_backward(loss)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (5, 3, 2, 4))
if __name__ == "__main__":
unittest.main()
......@@ -27,6 +27,7 @@ from ..fluid.data_feeder import (
from ..framework import LayerHelper, in_dynamic_mode
from .creation import full
from .manipulation import cast
from .math import _get_reduce_axis
__all__ = []
......@@ -198,7 +199,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
y = paddle.rand([10])
z = paddle.matmul(x, y)
print(z.shape)
# (1,)
# ()
# matrix * vector
x = paddle.rand([10, 5])
......@@ -459,12 +460,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
reduce_out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype()
)
reduce_all = (
True if axis is None or axis == [] or asvector else False
)
axis = axis if axis is not None and axis != [] else [0]
reduce_all, axis = _get_reduce_axis(axis, x)
reduce_type = (
'reduce_max' if porder == np.float64('inf') else 'reduce_min'
)
......@@ -516,6 +512,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
sum_out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
)
reduce_all, axis = _get_reduce_axis(axis, x)
block.append_op(
type='reduce_sum',
inputs={'X': pow_out},
......@@ -523,7 +520,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
attrs={
'dim': axis,
'keep_dim': keepdim,
'reduce_all': True if axis is None else False,
'reduce_all': reduce_all,
},
)
block.append_op(
......@@ -834,8 +831,6 @@ def cond(x, p=None, name=None):
if porder == -1 or porder == -np.inf:
return _C_ops.min(sum_out, [-1], False)
else:
reduce_all = True if axis is None or axis == [] else False
axis = axis if axis is not None and axis != [] else [0]
block = LayerHelper('norm', **locals())
abs_out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
......@@ -849,6 +844,8 @@ def cond(x, p=None, name=None):
block.append_op(
type='abs', inputs={'X': input}, outputs={'Out': abs_out}
)
reduce_all, axis = _get_reduce_axis(axis, x)
block.append_op(
type='reduce_sum',
inputs={'X': abs_out},
......@@ -894,7 +891,6 @@ def cond(x, p=None, name=None):
sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False)
return _C_ops.pow(sum_out_2, float(1.0 / porder))
else:
reduce_all = True if axis is None or axis == [] else False
block = LayerHelper('norm', **locals())
pow_out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
......@@ -914,6 +910,8 @@ def cond(x, p=None, name=None):
outputs={'Out': pow_out},
attrs={'factor': porder},
)
reduce_all, axis = _get_reduce_axis(axis, x)
block.append_op(
type='reduce_sum',
inputs={'X': pow_out},
......@@ -960,7 +958,7 @@ def cond(x, p=None, name=None):
if porder == -2:
return _C_ops.divide(min_out, max_out)
else:
reduce_all = True if axis is None or axis == [] else False
reduce_all, axis = _get_reduce_axis(axis, x)
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype()
......
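The scattered `reduce_all = True if axis is None or axis == [] else False` snippets in norm/cond are replaced by the shared `_get_reduce_axis` helper imported at the top of the file. A hedged sketch of the normalization it is used for here (illustrative only; the real helper lives in python/paddle/tensor/math.py and may differ in detail):

```python
def _get_reduce_axis_sketch(axis, x):
    # normalize axis to a list
    if isinstance(axis, (tuple, range)):
        axis = list(axis)
    elif isinstance(axis, int):
        axis = [axis]
    if axis is None:
        axis = []
    # an empty axis list, or one covering every dim, means "reduce all"
    reduce_all = axis == [] or len(axis) == len(x.shape)
    return reduce_all, axis
```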
......@@ -90,25 +90,16 @@ class TestJacobianNoBatch(unittest.TestCase):
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
if isinstance(self._actual, (tuple, list)):
self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
self._actual = paddle.concat([x[:] for x in self._actual], axis=0)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index('all', (slice(0, None, None), slice(0, None, None))),
Index('row', (0, slice(0, None, None))),
Index('col', (slice(0, None, None), 0)),
Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
)
self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
np.testing.assert_allclose(
self._actual.flatten(),
self._expected.flatten(),
rtol=self._rtol,
atol=self._atol,
)
self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
def test_jacobian_attribute_operator(self):
xs = (
......@@ -121,25 +112,16 @@ class TestJacobianNoBatch(unittest.TestCase):
)
self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
if isinstance(self._actual, (tuple, list)):
self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
self._actual = paddle.concat([x[:] for x in self._actual], axis=0)
self._expected = self._get_expected()
Index = collections.namedtuple('Index', ('type', 'value'))
indexes = (
Index('all', (slice(0, None, None), slice(0, None, None))),
Index('row', (0, slice(0, None, None))),
Index('col', (slice(0, None, None), 0)),
Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
)
self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
for index in indexes:
np.testing.assert_allclose(
self._actual.__getitem__(index.value),
self._expected.__getitem__(index.value),
rtol=self._rtol,
atol=self._atol,
err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
)
np.testing.assert_allclose(
self._actual.flatten(),
self._expected.flatten(),
rtol=self._rtol,
atol=self._atol,
)
def _get_expected(self):
xs = (
......
......@@ -398,6 +398,7 @@ def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM):
return src
if not isinstance(src[0], typing.Sequence):
src = [src]
return concat_row(tuple(concat_col(xs) for xs in src))
......
......@@ -1029,8 +1029,8 @@ class TestReshapeTransform(unittest.TestCase):
self.assertEqual(out.shape, [1, 1])
self.assertEqual(reshape.inverse(out).shape, [])
# self.assertEqual(reshape.forward_log_det_jacobian(x).shape, [])
# self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, [])
self.assertEqual(reshape.forward_log_det_jacobian(x).shape, [])
self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, [])
self.assertEqual(reshape.forward_shape(x.shape), (1, 1))
self.assertEqual(reshape.inverse_shape(out.shape), ())
......
......@@ -46,7 +46,7 @@ def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.atleast_1d(np.matmul(X, Y))
Out = np.matmul(X, Y)
return Out
......
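With the `np.atleast_1d` wrapper dropped, the NumPy reference now produces a 0-D array for the vector-vector case, matching the new paddle.matmul output. For example:

```python
import numpy as np

a = np.random.rand(2)
b = np.random.rand(2)
out = np.matmul(a, b)
print(out.ndim, out.shape)  # 0 () -- a 0-D NumPy array, like the 0-D paddle output
```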
......@@ -56,12 +56,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
y_dims = Y.shape
Y = Y.reshape((y_dims[0] * y_dims[1], y_dims[2]))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float32")
return Out
......@@ -141,9 +135,6 @@ class XPUTestMatmulOpErr(XPUOpTestWrapper):
with fluid.program_guard(fluid.Program()):
x = paddle.static.data(name="x", shape=[2], dtype=self.in_type)
y = paddle.static.data(name='y', shape=[2], dtype=self.in_type)
res = paddle.static.data(
name="output", shape=[1], dtype=self.in_type
)
result = paddle.mm(x, y)
exe = fluid.Executor(fluid.XPUPlace(0))
data1 = np.random.rand(2).astype(self.in_type)
......@@ -151,9 +142,7 @@ class XPUTestMatmulOpErr(XPUOpTestWrapper):
np_res = exe.run(
feed={'x': data1, 'y': data2}, fetch_list=[result]
)
expected_result = np.matmul(
data1.reshape(1, 2), data2.reshape(2, 1)
)
expected_result = np.matmul(data1, data2)
np.testing.assert_allclose(np_res, expected_result, atol=1e-3)
......
......@@ -46,12 +46,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
return Out
......
......@@ -2256,6 +2256,33 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out2.shape, [])
self.assertEqual(out2, 2.5)
def test_matmul(self):
# 1) no transpose
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out1 = paddle.matmul(x, y)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(x.grad.shape, [10])
self.assertEqual(y.grad.shape, [10])
# 2) transpose x and y
x = paddle.randn([10])
x.stop_gradient = False
y = paddle.randn([10])
y.stop_gradient = False
out2 = paddle.matmul(x, y, True, True)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(x.grad.shape, [10])
self.assertEqual(y.grad.shape, [10])
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
......