未验证 提交 849d937b 编写于 作者: A Aganlengzi 提交者: GitHub

[fix] addmm supports 1-d input (#42959)

* addmm supports 1-d input

* fix coverage

* fix

* more ut
上级 114a5d21
...@@ -113,23 +113,23 @@ void AddmmInferMeta(const MetaTensor& input, ...@@ -113,23 +113,23 @@ void AddmmInferMeta(const MetaTensor& input,
"if you put exe.run(startup_program) " "if you put exe.run(startup_program) "
"after optimizer.minimize function.")); "after optimizer.minimize function."));
// dim check // dim check
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(ndim_input == 2 || ndim_input == 1,
ndim_input, true,
2, errors::InvalidArgument(
errors::InvalidArgument("The input tensor input's dimension must be 2. " "The input tensor input's dimension must be 2 or 1. "
"But received input's dimension = [%s].", "But received input's dimension = [%d].",
ndim_input)); ndim_input));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ndim_x, ndim_x,
2, 2,
errors::InvalidArgument("The input tensor x's dimension must be 2. " errors::InvalidArgument("The input tensor x's dimension must be 2. "
"But received x's dimension = [%s].", "But received x's dimension = [%d].",
ndim_x)); ndim_x));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ndim_y, ndim_y,
2, 2,
errors::InvalidArgument("The input tensor y's dimension must be 2. " errors::InvalidArgument("The input tensor y's dimension must be 2. "
"But received y's dimension = [%s].", "But received y's dimension = [%d].",
ndim_y)); ndim_y));
std::vector<int64_t> output_dims; std::vector<int64_t> output_dims;
......
...@@ -44,6 +44,10 @@ void AddmmGradKernel(const Context& dev_ctx, ...@@ -44,6 +44,10 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* x_grad, DenseTensor* x_grad,
DenseTensor* y_grad) { DenseTensor* y_grad) {
auto in_dims = input.dims(); auto in_dims = input.dims();
if (input.dims().size() == 1) {
in_dims = {1, input.dims()[0]};
input_grad->Resize(in_dims);
}
int total_elems = 0; int total_elems = 0;
VLOG(3) << "alpha: " << alpha << " beta: " << beta; VLOG(3) << "alpha: " << alpha << " beta: " << beta;
...@@ -85,6 +89,10 @@ void AddmmGradKernel(const Context& dev_ctx, ...@@ -85,6 +89,10 @@ void AddmmGradKernel(const Context& dev_ctx,
} }
blas.SCAL(total_elems, beta, input_grad->data<T>()); blas.SCAL(total_elems, beta, input_grad->data<T>());
if (input.dims().size() == 1) {
input_grad->Resize(input.dims());
}
} }
if (x_grad) { if (x_grad) {
dev_ctx.template Alloc<T>(x_grad); dev_ctx.template Alloc<T>(x_grad);
......
...@@ -44,6 +44,12 @@ void AddmmKernel(const Context& dev_ctx, ...@@ -44,6 +44,12 @@ void AddmmKernel(const Context& dev_ctx,
auto x_dims = x.dims(); auto x_dims = x.dims();
auto y_dims = y.dims(); auto y_dims = y.dims();
DenseTensor input_2d(input);
if (input.dims().size() == 1) {
input_dims = {1, input.dims()[0]};
input_2d.Resize(input_dims);
}
// broadcast mode check // broadcast mode check
if (x_dims[0] != input_dims[0]) { if (x_dims[0] != input_dims[0]) {
PADDLE_ENFORCE_EQ(input_dims[0], PADDLE_ENFORCE_EQ(input_dims[0],
...@@ -97,7 +103,8 @@ void AddmmKernel(const Context& dev_ctx, ...@@ -97,7 +103,8 @@ void AddmmKernel(const Context& dev_ctx,
bcast_dims[1] = y_dims[1] / input_dims[1]; bcast_dims[1] = y_dims[1] / input_dims[1];
VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]"; VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
// broadcast using eigen // broadcast using eigen
auto eigen_input = PhiEigenTensor<T, 2>::From(input); const DenseTensor& const_ref_input = input_2d;
auto eigen_input = PhiEigenTensor<T, 2>::From(const_ref_input);
auto eigen_out = PhiEigenTensor<T, 2>::From(*out); auto eigen_out = PhiEigenTensor<T, 2>::From(*out);
auto& place = *dev_ctx.eigen_device(); auto& place = *dev_ctx.eigen_device();
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval( funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
......
...@@ -221,7 +221,44 @@ class TestAddMMOp3(OpTest): ...@@ -221,7 +221,44 @@ class TestAddMMOp3(OpTest):
self.check_grad(['Input'], 'Out', no_grad_set=None) self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMOp4(unittest.TestCase): class TestAddMMOp4(OpTest):
# test broadcast
def setUp(self):
self.op_type = "addmm"
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((100)).astype(self.dtype),
'X': np.random.random((20, 10)).astype(self.dtype),
'Y': np.random.random((10, 100)).astype(self.dtype),
}
self.attrs = {
'Alpha': 0.5,
'Beta': 2.0,
}
self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input', 'X', 'Y'], 'Out')
def test_check_grad_x(self):
self.check_grad(['X'], 'Out', no_grad_set=None)
def test_check_grad_y(self):
self.check_grad(['Y'], 'Out', no_grad_set=None)
def test_check_grad_input(self):
self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMOp5(unittest.TestCase):
def test_api_with_dygraph(self): def test_api_with_dygraph(self):
np_input = np.random.random((20, 30)).astype(np.float32) np_input = np.random.random((20, 30)).astype(np.float32)
np_x = np.random.random((20, 6)).astype(np.float32) np_x = np.random.random((20, 6)).astype(np.float32)
...@@ -235,7 +272,6 @@ class TestAddMMOp4(unittest.TestCase): ...@@ -235,7 +272,6 @@ class TestAddMMOp4(unittest.TestCase):
assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy()) assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy())
'''
class TestAddMMAPI(unittest.TestCase): class TestAddMMAPI(unittest.TestCase):
def test_api_error(self): def test_api_error(self):
data_x = np.ones((2, 2)).astype(np.float32) data_x = np.ones((2, 2)).astype(np.float32)
...@@ -249,9 +285,106 @@ class TestAddMMAPI(unittest.TestCase): ...@@ -249,9 +285,106 @@ class TestAddMMAPI(unittest.TestCase):
x = paddle.to_tensor(data_x_wrong) x = paddle.to_tensor(data_x_wrong)
y = paddle.to_tensor(data_y) y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input) input = paddle.to_tensor(data_input)
out = paddle.tensor.addmm( input=input, x=x, y=y, beta=0.5, alpha=5.0 ) out = paddle.tensor.addmm(
input=input, x=x, y=y, beta=0.5, alpha=5.0)
self.assertRaises(ValueError, test_error1) self.assertRaises(ValueError, test_error1)
'''
def test_error2():
data_x_wrong = np.ones((2)).astype(np.float32)
x = paddle.to_tensor(data_x_wrong)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input)
out = paddle.tensor.addmm(
input=input, x=x, y=y, beta=0.5, alpha=5.0)
self.assertRaises(ValueError, test_error2)
def test_error3():
data_input_wrong = np.ones((2, 2, 2)).astype(np.float32)
x = paddle.to_tensor(data_x)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input_wrong)
out = paddle.tensor.addmm(
input=input, x=x, y=y, beta=0.5, alpha=5.0)
self.assertRaises(ValueError, test_error3)
def test_error4():
data_input_wrong = np.ones((5)).astype(np.float32)
x = paddle.to_tensor(data_x)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input_wrong)
out = paddle.tensor.addmm(
input=input, x=x, y=y, beta=0.5, alpha=5.0)
self.assertRaises(ValueError, test_error4)
paddle.enable_static()
def test_api_normal_1(self):
data_x = np.ones((2, 2)).astype(np.float32)
data_y = np.ones((2, 2)).astype(np.float32)
data_input = np.ones((2, 2)).astype(np.float32)
data_alpha = 0.1
data_beta = 1.0
paddle.disable_static()
x = paddle.to_tensor(data_x)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input)
paddle_output = paddle.tensor.addmm(
input=input, x=x, y=y, beta=data_beta, alpha=data_alpha)
numpy_output = data_beta * data_input + data_alpha * np.dot(data_x,
data_y)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), True)
paddle.enable_static()
def test_api_normal_2(self):
data_x = np.ones((3, 10)).astype(np.float32)
data_y = np.ones((10, 3)).astype(np.float32)
data_input = np.ones((3)).astype(np.float32)
data_alpha = 0.1
data_beta = 1.0
paddle.disable_static()
x = paddle.to_tensor(data_x)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input)
paddle_output = paddle.tensor.addmm(
input=input, x=x, y=y, beta=data_beta, alpha=data_alpha)
numpy_output = data_beta * data_input + data_alpha * np.dot(data_x,
data_y)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), True)
paddle.enable_static()
def test_api_normal_3(self):
data_x = np.ones((3, 10)).astype(np.float32)
data_y = np.ones((10, 3)).astype(np.float32)
data_input = np.ones((1)).astype(np.float32)
data_alpha = 0.1
data_beta = 1.0
paddle.disable_static()
x = paddle.to_tensor(data_x)
y = paddle.to_tensor(data_y)
input = paddle.to_tensor(data_input)
paddle_output = paddle.tensor.addmm(
input=input, x=x, y=y, beta=data_beta, alpha=data_alpha)
numpy_output = data_beta * data_input + data_alpha * np.dot(data_x,
data_y)
self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), True)
paddle.enable_static()
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
......
...@@ -1610,8 +1610,11 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): ...@@ -1610,8 +1610,11 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
input_shape = input.shape input_shape = input.shape
x_shape = x.shape x_shape = x.shape
y_shape = y.shape y_shape = y.shape
if not len(input_shape) == len(x_shape) == len(y_shape) == 2: if not len(x_shape) == len(y_shape) == 2:
raise ValueError("The dimention of input, x, y should be 2 but receive input's shape: {}, x's shape: {}, y's shape: {}".format(input_shape, x_shape, y_shape)) raise ValueError("The dimention of x, y should be 2 but receive x's shape: {}, y's shape: {}".format(x_shape, y_shape))
if x_shape[1] != y_shape[0]:
raise ValueError("The input Variable x's width must be equal with Variable y' height. But received x's shape = {}, y's shape = {}.".format(x_shape, y_shape))
if len(input_shape) == 2:
if input_shape[0] != x_shape[0]: if input_shape[0] != x_shape[0]:
if input_shape[0] != 1: if input_shape[0] != 1:
raise ValueError( "When x's dimension[0] is not equal with input's dimension[0], input's dimension[0] must be 1 but got {}".format(input_shape[0])) raise ValueError( "When x's dimension[0] is not equal with input's dimension[0], input's dimension[0] must be 1 but got {}".format(input_shape[0]))
...@@ -1620,10 +1623,11 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None): ...@@ -1620,10 +1623,11 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
if input_shape[1] != y_shape[1]: if input_shape[1] != y_shape[1]:
if input_shape[1] != 1: if input_shape[1] != 1:
raise ValueError( "When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format(input_shape[1])) raise ValueError( "When y's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {}".format(input_shape[1]))
if input_shape[0] != x_shape[0] and input_shape[0] != 1: elif len(input_shape) == 1:
raise ValueError( "When x's dimension[0] is not equal with input's dimension[0], input's dimension[0] must be 1 but got {}".format(input_shape[0])) if input_shape[0] not in (y_shape[1], 1):
if x_shape[1] != y_shape[0]: raise ValueError("The input's shape: {} is not broadcastable with [x.shape[0], y.shape[1]]: [{},{}]".format(input_shape, x_shape[0], y_shape[1]))
raise ValueError("The input Variable x's width must be equal with Variable y' height. But received x's shape = {}, y's shape = {}.".format(x_shape, y_shape)) else:
raise ValueError("The dimention of input should be 2 or 1 but receive input's shape: {}".format(input_shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册