Unverified · Commit a64a722a, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support 0-D tensors in ReshapeTransform/nll_loss/matmul (#53828)

Parent 5745a63f
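What this change does, sketched in a few lines of dynamic-graph Python (a minimal illustration based on the tests added below; the tensor values are random, only the shapes matter):

```python
import paddle
import paddle.nn.functional as F

# vector * vector matmul now yields a 0-D tensor instead of a shape-[1] tensor
x = paddle.randn([10])
y = paddle.randn([10])
print(paddle.matmul(x, y).shape)  # [] after this change (previously [1])

# a reduced nll_loss is likewise 0-D now
log_prob = F.log_softmax(paddle.rand([5, 3]), axis=1)
label = paddle.randint(0, 3, [5], "int64")
print(F.nll_loss(log_prob, label).shape)  # [] after this change (previously [1])
```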
@@ -666,10 +666,6 @@ class MatMulOp : public framework::OperatorWithKernel {
       dim_out.resize(dim_out.size() - 1);
     }
-    if (dim_out.empty()) {
-      dim_out = {1};
-    }
     phi::DDim ddim_out = phi::make_ddim(dim_out);
     context->SetOutputDim("Out", ddim_out);
......
@@ -91,9 +91,6 @@ void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const {
     if (!y_broadcasted) {
       new_dims.push_back(N);
     }
-    if (x_broadcasted && y_broadcasted) {
-      new_dims.push_back(1);
-    }
     ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
     ctx->ShareLoD("X", "Out");
......
@@ -334,8 +334,9 @@ void ExecuteMatMulV1(const ExecutionContext &ctx,
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  out->set_mem_desc(
-      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
+  auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
+                                              : std::vector<int64_t>{1};
+  out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
 }

 template <typename T>
......
@@ -146,8 +146,9 @@ inline void ExecuteMul(const OneDNNContext& dev_ctx,
   // This kernel is flattening dims so then we need to unflattened version
   // that should be set in out reshape require plain layout, but
   // MatmulV2MKLDNNHanlder enforces one so it should work
-  out->set_mem_desc(
-      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
+  auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
+                                              : std::vector<int64_t>{1};
+  out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
 }

 template <typename T, typename T_out>
@@ -177,8 +178,9 @@ inline void ExecuteMatmul(const OneDNNContext& dev_ctx,
   matmul_p->execute(astream, matmul_args);
   astream.wait();
-  out->set_mem_desc(
-      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
+  auto reshape_dims = out->dims().size() != 0 ? vectorize(out->dims())
+                                              : std::vector<int64_t>{1};
+  out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
 }

 }  // namespace funcs
......
@@ -822,11 +822,11 @@ void NllLossGradInferMeta(const MetaTensor& x,
   if (check) {
     auto batch_size = x_dims[0];
     if (x_dims.size() == 2) {
-      PADDLE_ENFORCE_EQ(dout_dims.size(),
-                        1,
-                        phi::errors::InvalidArgument(
-                            "The dimensions of Input(Out@Grad) must be 1"));
       if (reduction == "none") {
+        PADDLE_ENFORCE_EQ(dout_dims.size(),
+                          1,
+                          phi::errors::InvalidArgument(
+                              "The dimensions of Input(Out@Grad) must be 1"));
         PADDLE_ENFORCE_EQ(
             dout_dims[0],
             batch_size,
@@ -834,10 +834,10 @@ void NllLossGradInferMeta(const MetaTensor& x,
                 "The unreduced size ofInput(Out@Grad) must be the "
                 "same as batch_size."));
       } else {
-        PADDLE_ENFORCE_EQ(dout_dims[0],
-                          1,
-                          phi::errors::InvalidArgument(
-                              "The reduced size of Input(Out@Grad) must be 1"));
+        PADDLE_ENFORCE_EQ(dout_dims.size(),
+                          0,
+                          phi::errors::InvalidArgument(
+                              "The dimensions of Input(Out@Grad) must be 0"));
       }
     } else if (x_dims.size() == 4) {
       if (reduction == "none") {
@@ -855,10 +855,10 @@ void NllLossGradInferMeta(const MetaTensor& x,
                 "The dimensions of Input(Out@Grad) must be match "
                 "to Input(Label) dimensions."));
       } else {
-        PADDLE_ENFORCE_EQ(dout_dims[0],
-                          1,
-                          phi::errors::InvalidArgument(
-                              "The reduced size of Input(Out@Grad) must be 1"));
+        PADDLE_ENFORCE_EQ(dout_dims.size(),
+                          0,
+                          phi::errors::InvalidArgument(
+                              "The dimensions of Input(Out@Grad) must be 0"));
       }
     }
   }
......
@@ -2057,9 +2057,6 @@ void MatmulInferMeta(const MetaTensor& x,
   if (!y_broadcasted) {
     new_dims.push_back(N);
   }
-  if (x_broadcasted && y_broadcasted) {
-    new_dims.push_back(1);
-  }
   auto ddim_out = phi::make_ddim(new_dims);
......
@@ -831,7 +831,7 @@ void NllLossRawInferMeta(const MetaTensor& input,
     if (reduction == "none") {
       out->set_dims({x_dims[0]});
     } else {
-      out->set_dims({1});
+      out->set_dims(phi::make_ddim({}));
     }
   } else if (x_dims.size() == 4) {
     PADDLE_ENFORCE_EQ(label_dims.size(),
@@ -854,10 +854,10 @@ void NllLossRawInferMeta(const MetaTensor& input,
     if (reduction == "none") {
       out->set_dims({x_dims[0], x_dims[2], x_dims[3]});
     } else {
-      out->set_dims({1});
+      out->set_dims(phi::make_ddim({}));
     }
   }
-  total_weight->set_dims({1});
+  total_weight->set_dims(phi::make_ddim({}));
   out->set_dtype(input.dtype());
   total_weight->set_dtype(input.dtype());
 }
......
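For reference, the output-shape contract that NllLossRawInferMeta now describes, written out with the functional API (a reading of the infer-meta code above rather than an excerpt from the diff; the concrete sizes are only illustrative):

```python
import paddle
import paddle.nn.functional as F

# 2-D input: [N, C] with labels [N]
x2d = F.log_softmax(paddle.rand([4, 3]), axis=1)
l2d = paddle.randint(0, 3, [4], "int64")
print(F.nll_loss(x2d, l2d, reduction='none').shape)  # [4]
print(F.nll_loss(x2d, l2d, reduction='mean').shape)  # [] (was [1] before this change)

# 4-D input: [N, C, H, W] with labels [N, H, W]
x4d = F.log_softmax(paddle.rand([4, 3, 2, 2]), axis=1)
l4d = paddle.randint(0, 3, [4, 2, 2], "int64")
print(F.nll_loss(x4d, l4d, reduction='none').shape)  # [4, 2, 2]
print(F.nll_loss(x4d, l4d, reduction='sum').shape)   # [] (was [1] before this change)
```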
@@ -126,7 +126,7 @@ void MatMulFunctionImplWithBlas(
           M,
           N));
   VLOG(3) << "MatMul's case 1";
-  Out->Resize({1});
+  Out->Resize(phi::make_ddim({}));
   dev_ctx.template Alloc<T>(Out);
   blas.GEMM(CblasNoTrans,
             CblasTrans,
@@ -516,7 +516,7 @@ void MatMulFunctionImplWithCublasLt(
           N));
   // MatMul's case 0 => vector * vector
-  Out->Resize({1});
+  Out->Resize(phi::make_ddim({}));
   dev_ctx.template Alloc<T>(Out);
   VLOG(3) << "MatMul with blaslt case 1";
   blaslt::Run(dev_ctx,
......
@@ -24,8 +24,7 @@ void NllLossKernel(const Context& dev_ctx,
                    const std::string& reduction,
                    DenseTensor* out) {
   DenseTensor total_weight;
-  total_weight.set_meta(
-      DenseTensorMeta(phi::CppTypeToDataType<T>::Type(), {1}));
+  total_weight.set_meta(DenseTensorMeta(phi::CppTypeToDataType<T>::Type(), {}));
   dev_ctx.template Alloc<T>(total_weight);
   NllLossRawKernel(dev_ctx,
                    input,
......
@@ -856,8 +856,8 @@ class ReshapeTransform(Transform):
            # [[[1., 1., 1.],
            #   [1., 1., 1.]]])
            print(reshape_transform.forward_log_det_jacobian(x))
-           # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
-           #        [0.])
+           # Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+           #        0.)
    """
    _type = Type.BIJECTION
@@ -945,8 +945,7 @@ class ReshapeTransform(Transform):
        )

    def _forward_log_det_jacobian(self, x):
-        # TODO(zhouwei): should not set shape to [1], which is []
-        shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1]
+        shape = x.shape[: x.dim() - len(self._in_event_shape)]
        return paddle.zeros(shape, dtype=x.dtype)
......
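A small sketch of what the _forward_log_det_jacobian simplification means for callers (the (2, 3) -> (3, 2) event shapes are an illustrative choice, mirroring the updated docstring and test expectations):

```python
import paddle

x = paddle.ones([2, 3])
t = paddle.distribution.ReshapeTransform((2, 3), (3, 2))

# The event shape covers every axis of x, so the remaining batch shape is
# empty and the log-det-Jacobian is a 0-D tensor (previously padded to [1]).
print(t.forward_log_det_jacobian(x).shape)  # []
```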
@@ -77,12 +77,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
         Y = np.transpose(Y, tuple(dim))
     Out = np.matmul(X, Y)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float32")
     return Out
@@ -177,9 +171,6 @@ class API_TestMm(unittest.TestCase):
         with fluid.program_guard(fluid.Program()):
             x = paddle.static.data(name="x", shape=[2], dtype="float64")
             y = paddle.static.data(name='y', shape=[2], dtype='float64')
-            res = paddle.static.data(
-                name="output", shape=[1], dtype="float64"
-            )
             result = paddle.mm(x, y)
             exe = fluid.Executor(fluid.CPUPlace())
             data1 = np.random.rand(2)
@@ -187,9 +178,7 @@ class API_TestMm(unittest.TestCase):
             np_res = exe.run(
                 feed={'x': data1, 'y': data2}, fetch_list=[result]
             )
-            expected_result = np.matmul(
-                data1.reshape(1, 2), data2.reshape(2, 1)
-            )
+            expected_result = np.matmul(data1, data2)
             np.testing.assert_allclose(
                 np_res,
......
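For context on why the wrapping branch and the unused shape=[1] fetch placeholder could go away: NumPy already produces a 0-d result for a vector-vector product, which now lines up with Paddle's 0-D output (plain NumPy behavior, not taken from this diff):

```python
import numpy as np

a = np.random.rand(2)
b = np.random.rand(2)
out = np.matmul(a, b)   # inner product of two 1-D arrays
print(np.shape(out))    # () -- a 0-d result, matching paddle.mm / paddle.matmul now
```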
@@ -102,12 +102,6 @@ def reference_matmul_mul_head(
         Y = transpose_mat(Y)
     Out = matmul_head(X, Y, head_number)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float32")
     return Out
@@ -196,12 +190,6 @@ def reference_matmul_mul_head2(
         Y = transpose_mat(Y)
     Out = matmul_head2(X, Y, head_number)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float32")
     return Out
......
@@ -45,12 +45,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
         Y = np.transpose(Y, tuple(dim))
     Out = np.matmul(X, Y)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float64")
     return Out
......
@@ -2571,6 +2571,33 @@ class TestSundryAPI(unittest.TestCase):
         self.assertEqual(out2.shape, [])
         self.assertEqual(out2, 2.5)

+    def test_matmul(self):
+        # 1) no transpose
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out1 = paddle.matmul(x, y)
+        out1.retain_grads()
+        out1.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(x.grad.shape, [10])
+        self.assertEqual(y.grad.shape, [10])
+
+        # 2) transpose x and y
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out2 = paddle.matmul(x, y, True, True)
+        out2.retain_grads()
+        out2.backward()
+
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(x.grad.shape, [10])
+        self.assertEqual(y.grad.shape, [10])
+
     def test_linalg_slogdet(self):
         # 2-D input
         x = paddle.randn([3, 3])
@@ -4872,6 +4899,40 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[1].shape, ())
         self.assertEqual(res[1], 2.5)

+    @prog_scope()
+    def test_matmul(self):
+        # 1) no transpose
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out = paddle.matmul(x, y)
+        paddle.static.append_backward(out)
+        self.assertEqual(out.shape, ())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (10,))
+        self.assertEqual(res[2].shape, (10,))
+
+        # 2) transpose x and y
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out = paddle.matmul(x, y, True, True)
+        paddle.static.append_backward(out)
+        self.assertEqual(out.shape, ())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (10,))
+        self.assertEqual(res[2].shape, (10,))
+
     @prog_scope()
     def test_linalg_slogdet(self):
         # 2-D input
@@ -5994,6 +6055,31 @@ class TestLossAPI(unittest.TestCase):
         self.assertEqual(loss.shape, [])
         self.assertEqual(input.grad.shape, [3, 5])

+    def test_nll_loss(self):
+        input = paddle.rand([5, 3])
+        input.stop_gradient = False
+        log_softmax = paddle.nn.LogSoftmax(axis=1)
+        log_out = log_softmax(input)
+        label = paddle.randint(0, 3, [5], "int64")
+
+        loss = paddle.nn.functional.nll_loss(log_out, label)
+        loss.backward()
+        self.assertEqual(loss.shape, [])
+        self.assertEqual(input.grad.shape, [5, 3])
+
+        input = paddle.rand([5, 3, 2, 4])
+        input.stop_gradient = False
+        log_softmax = paddle.nn.LogSoftmax(axis=1)
+        log_out = log_softmax(input)
+        label = paddle.randint(0, 3, [5, 2, 4], "int64")
+
+        loss = paddle.nn.functional.nll_loss(log_out, label)
+        loss.backward()
+        self.assertEqual(loss.shape, [])
+        self.assertEqual(input.grad.shape, [5, 3, 2, 4])
+
 class TestLossAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -6060,6 +6146,40 @@ class TestLossAPIStatic(unittest.TestCase):
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, (3, 5))

+    @prog_scope()
+    def test_nll_loss(self):
+        input = paddle.rand([5, 3])
+        input.stop_gradient = False
+        log_softmax = paddle.nn.LogSoftmax(axis=1)
+        log_out = log_softmax(input)
+        label = paddle.randint(0, 3, shape=[5])
+        label.stop_gradient = False
+
+        loss = paddle.nn.functional.nll_loss(log_out, label)
+        paddle.static.append_backward(loss)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (5, 3))
+
+        input = paddle.rand([5, 3, 2, 4])
+        input.stop_gradient = False
+        log_softmax = paddle.nn.LogSoftmax(axis=1)
+        log_out = log_softmax(input)
+        label = paddle.randint(0, 3, shape=[5, 2, 4])
+        label.stop_gradient = False
+
+        loss = paddle.nn.functional.nll_loss(log_out, label)
+        paddle.static.append_backward(loss)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (5, 3, 2, 4))
+
 if __name__ == "__main__":
     unittest.main()
......
@@ -27,6 +27,7 @@ from ..fluid.data_feeder import (
 from ..framework import LayerHelper, in_dynamic_mode
 from .creation import full
 from .manipulation import cast
+from .math import _get_reduce_axis

 __all__ = []
@@ -198,7 +199,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
            y = paddle.rand([10])
            z = paddle.matmul(x, y)
            print(z.shape)
-           # (1,)
+           # ()

            # matrix * vector
            x = paddle.rand([10, 5])
@@ -459,12 +460,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
            reduce_out = helper.create_variable_for_type_inference(
                dtype=helper.input_dtype()
            )
-            reduce_all = (
-                True if axis is None or axis == [] or asvector else False
-            )
-            axis = axis if axis is not None and axis != [] else [0]
-
+            reduce_all, axis = _get_reduce_axis(axis, x)
            reduce_type = (
                'reduce_max' if porder == np.float64('inf') else 'reduce_min'
            )
@@ -516,6 +512,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
            sum_out = block.create_variable_for_type_inference(
                dtype=block.input_dtype()
            )
+            reduce_all, axis = _get_reduce_axis(axis, x)
            block.append_op(
                type='reduce_sum',
                inputs={'X': pow_out},
@@ -523,7 +520,7 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
                attrs={
                    'dim': axis,
                    'keep_dim': keepdim,
-                    'reduce_all': True if axis is None else False,
+                    'reduce_all': reduce_all,
                },
            )
            block.append_op(
@@ -834,8 +831,6 @@ def cond(x, p=None, name=None):
            if porder == -1 or porder == -np.inf:
                return _C_ops.min(sum_out, [-1], False)
        else:
-            reduce_all = True if axis is None or axis == [] else False
-            axis = axis if axis is not None and axis != [] else [0]
            block = LayerHelper('norm', **locals())
            abs_out = block.create_variable_for_type_inference(
                dtype=block.input_dtype()
@@ -849,6 +844,8 @@ def cond(x, p=None, name=None):
            block.append_op(
                type='abs', inputs={'X': input}, outputs={'Out': abs_out}
            )
+
+            reduce_all, axis = _get_reduce_axis(axis, x)
            block.append_op(
                type='reduce_sum',
                inputs={'X': abs_out},
@@ -894,7 +891,6 @@ def cond(x, p=None, name=None):
                sum_out_2 = _C_ops.sum(sum_out_1, axis, None, False)
                return _C_ops.pow(sum_out_2, float(1.0 / porder))
        else:
-            reduce_all = True if axis is None or axis == [] else False
            block = LayerHelper('norm', **locals())
            pow_out = block.create_variable_for_type_inference(
                dtype=block.input_dtype()
@@ -914,6 +910,8 @@ def cond(x, p=None, name=None):
                outputs={'Out': pow_out},
                attrs={'factor': porder},
            )
+
+            reduce_all, axis = _get_reduce_axis(axis, x)
            block.append_op(
                type='reduce_sum',
                inputs={'X': pow_out},
@@ -960,7 +958,7 @@ def cond(x, p=None, name=None):
            if porder == -2:
                return _C_ops.divide(min_out, max_out)
        else:
-            reduce_all = True if axis is None or axis == [] else False
+            reduce_all, axis = _get_reduce_axis(axis, x)
            block = LayerHelper('norm', **locals())
            out = block.create_variable_for_type_inference(
                dtype=block.input_dtype()
......
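The repeated inline reduce_all / axis normalization in norm and cond is replaced by the shared _get_reduce_axis helper imported above. A rough sketch of what that helper is assumed to do, written only to make the refactor readable here (the real implementation lives in python/paddle/tensor/math.py and may differ in details):

```python
def _get_reduce_axis_sketch(axis, x):
    # Hypothetical stand-in for paddle.tensor.math._get_reduce_axis (assumed
    # behavior): normalize `axis` to a list and decide whether the reduction
    # covers every dimension of `x`, which is what permits a 0-D result.
    if axis is None:
        axis = []
    elif isinstance(axis, (tuple, range)):
        axis = list(axis)
    elif isinstance(axis, int):
        axis = [axis]
    reduce_all = axis == [] or len(axis) == len(x.shape)
    return reduce_all, axis
```

Compared with the removed inline code, the relevant difference is that an empty axis stays empty (with reduce_all=True) instead of being forced to [0], which is what lets fully reduced results come out as 0-D.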
@@ -90,25 +90,16 @@ class TestJacobianNoBatch(unittest.TestCase):
        )
        self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
        if isinstance(self._actual, (tuple, list)):
-            self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
+            self._actual = paddle.concat([x[:] for x in self._actual], axis=0)
        self._expected = self._get_expected()

-        Index = collections.namedtuple('Index', ('type', 'value'))
-        indexes = (
-            Index('all', (slice(0, None, None), slice(0, None, None))),
-            Index('row', (0, slice(0, None, None))),
-            Index('col', (slice(0, None, None), 0)),
-            Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
-        )
-        self.assertEqual(self._actual[:].numpy().dtype, self._expected.dtype)
-        for index in indexes:
-            np.testing.assert_allclose(
-                self._actual.__getitem__(index.value),
-                self._expected.__getitem__(index.value),
-                rtol=self._rtol,
-                atol=self._atol,
-                err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
-            )
+        self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
+        np.testing.assert_allclose(
+            self._actual.flatten(),
+            self._expected.flatten(),
+            rtol=self._rtol,
+            atol=self._atol,
+        )

    def test_jacobian_attribute_operator(self):
        xs = (
@@ -121,25 +112,16 @@ class TestJacobianNoBatch(unittest.TestCase):
        )
        self._actual = paddle.autograd.jacobian(ys, xs, batch_axis=None)
        if isinstance(self._actual, (tuple, list)):
-            self._actual = paddle.concat([x[:] for x in self._actual], axis=1)
+            self._actual = paddle.concat([x[:] for x in self._actual], axis=0)
        self._expected = self._get_expected()

-        Index = collections.namedtuple('Index', ('type', 'value'))
-        indexes = (
-            Index('all', (slice(0, None, None), slice(0, None, None))),
-            Index('row', (0, slice(0, None, None))),
-            Index('col', (slice(0, None, None), 0)),
-            Index('multi-row', (slice(0, 2, 1), slice(0, None, None))),
-        )
        self.assertEqual(self._actual.numpy().dtype, self._expected.dtype)
-        for index in indexes:
-            np.testing.assert_allclose(
-                self._actual.__getitem__(index.value),
-                self._expected.__getitem__(index.value),
-                rtol=self._rtol,
-                atol=self._atol,
-                err_msg=f'Testcase {index.type} index not passed, value is {index.value}',
-            )
+        np.testing.assert_allclose(
+            self._actual.flatten(),
+            self._expected.flatten(),
+            rtol=self._rtol,
+            atol=self._atol,
+        )

    def _get_expected(self):
        xs = (
......
@@ -398,6 +398,7 @@ def _np_concat_matrix_sequence(src, src_format=MatrixFormat.NM):
        return src
    if not isinstance(src[0], typing.Sequence):
        src = [src]
+
    return concat_row(tuple(concat_col(xs) for xs in src))
......
@@ -1029,8 +1029,8 @@ class TestReshapeTransform(unittest.TestCase):
        self.assertEqual(out.shape, [1, 1])
        self.assertEqual(reshape.inverse(out).shape, [])
-        # self.assertEqual(reshape.forward_log_det_jacobian(x).shape, [])
-        # self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, [])
+        self.assertEqual(reshape.forward_log_det_jacobian(x).shape, [])
+        self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, [])
        self.assertEqual(reshape.forward_shape(x.shape), (1, 1))
        self.assertEqual(reshape.inverse_shape(out.shape), ())
......
@@ -46,7 +46,7 @@ def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
            Y = np.transpose(Y, tuple(dim))
-    Out = np.atleast_1d(np.matmul(X, Y))
+    Out = np.matmul(X, Y)
    return Out
......
@@ -56,12 +56,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
        y_dims = Y.shape
        Y = Y.reshape((y_dims[0] * y_dims[1], y_dims[2]))
    Out = np.matmul(X, Y)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float32")
    return Out
@@ -141,9 +135,6 @@ class XPUTestMatmulOpErr(XPUOpTestWrapper):
            with fluid.program_guard(fluid.Program()):
                x = paddle.static.data(name="x", shape=[2], dtype=self.in_type)
                y = paddle.static.data(name='y', shape=[2], dtype=self.in_type)
-                res = paddle.static.data(
-                    name="output", shape=[1], dtype=self.in_type
-                )
                result = paddle.mm(x, y)
                exe = fluid.Executor(fluid.XPUPlace(0))
                data1 = np.random.rand(2).astype(self.in_type)
@@ -151,9 +142,7 @@ class XPUTestMatmulOpErr(XPUOpTestWrapper):
                np_res = exe.run(
                    feed={'x': data1, 'y': data2}, fetch_list=[result]
                )
-                expected_result = np.matmul(
-                    data1.reshape(1, 2), data2.reshape(2, 1)
-                )
+                expected_result = np.matmul(data1, data2)
                np.testing.assert_allclose(np_res, expected_result, atol=1e-3)
......
@@ -46,12 +46,6 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
            Y = np.transpose(Y, tuple(dim))
    Out = np.matmul(X, Y)
-    if not Out.shape:
-        # We do not support 0-dimensional Tensors (scalars). So where
-        # np.matmul outputs a scalar, we must convert to a Tensor of
-        # shape (1, ) instead.
-        # Everywhere else, we are compatible with np.matmul.
-        Out = np.array([Out], dtype="float64")
    return Out
......
@@ -2256,6 +2256,33 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out2.shape, [])
        self.assertEqual(out2, 2.5)

+    def test_matmul(self):
+        # 1) no transpose
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out1 = paddle.matmul(x, y)
+        out1.retain_grads()
+        out1.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(x.grad.shape, [10])
+        self.assertEqual(y.grad.shape, [10])
+
+        # 2) transpose x and y
+        x = paddle.randn([10])
+        x.stop_gradient = False
+        y = paddle.randn([10])
+        y.stop_gradient = False
+        out2 = paddle.matmul(x, y, True, True)
+        out2.retain_grads()
+        out2.backward()
+
+        self.assertEqual(out2.shape, [])
+        self.assertEqual(x.grad.shape, [10])
+        self.assertEqual(y.grad.shape, [10])
+
    def test_linalg_slogdet(self):
        # 2-D input
        x = paddle.randn([3, 3])
......