未验证 提交 ad7ac4c6 编写于 作者: Q Qinghe JING 提交者: GitHub

create bmm op and move several api from fluid.layers to tensor (#23457)

* add gradient check to reduce ops

* add skip gradient check to reduce ops test=develop

* modify stack api test=develop

* add bmm op and move serval ops from fluid.layers to tensor test=develop
上级 071a7020
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/bmm_op.h"
#include <vector>
namespace paddle {
namespace operators {
class BmmOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of BmmOp should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of BmmOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 3,
platform::errors::InvalidArgument(
"Input(X) of BmmOp must be 3-dimensional in BmmOp, "
"but received X's shape: [%s].",
x_dims));
PADDLE_ENFORCE_EQ(y_dims.size(), 3,
platform::errors::InvalidArgument(
"Input(Y) of BmmOp must be 3-dimensional in BmmOp, "
"but received Y's shape: [%s].",
y_dims));
PADDLE_ENFORCE_EQ(
x_dims[0], y_dims[0],
platform::errors::InvalidArgument(
"Input(X) and Input(Y) must have the same batch size in BmmOp, "
"but received X's batch size: [%s],"
"Y's batch size [%s]",
x_dims[0], y_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[2], y_dims[1],
platform::errors::InvalidArgument(
"Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
"but receive X's width: [%s],"
"Y's height: [%s].",
x_dims[2], y_dims[1]));
std::vector<int64_t> dim_out;
dim_out.push_back(x_dims[0]);
dim_out.push_back(x_dims[1]);
dim_out.push_back(y_dims[2]);
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class BmmOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of Bmm op.");
AddInput("Y", "(Tensor), The second input tensor of Bmm op.");
AddOutput("Out", "(Tensor), The output tensor of Bmm op.");
AddComment(R"DOC(
The Bmm operator is used to perform batched matrix multiplication
over the last two dimensions of the input tensors `X` and `Y`
which are both 3-dimentionsal.
Examples:
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
)DOC");
}
};
class BmmOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of BmmOp should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of BmmOp should not be null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Output(Out@GRAD) of BmmOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class BmmOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("bmm_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bmm, ops::BmmOp, ops::BmmOpMaker,
ops::BmmOpGradMaker<paddle::framework::OpDesc>,
ops::BmmOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad);
REGISTER_OP_CPU_KERNEL(
bmm, ops::BmmKernel<paddle::platform::CPUDeviceContext, float>,
ops::BmmKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
bmm_grad, ops::BmmGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::BmmGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bmm_op.h"
#ifdef PADDLE_WITH_CUDA
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bmm, ops::BmmKernel<paddle::platform::CUDADeviceContext, float>,
ops::BmmKernel<paddle::platform::CUDADeviceContext, double>,
ops::BmmKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
bmm_grad, ops::BmmGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::BmmGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::BmmGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_
#define PADDLE_FLUID_OPERATORS_BMM_OP_H_
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static void ReshapeTensorIntoMatrixSequence(
framework::Tensor *x, const math::MatDescriptor &descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
x->Resize({descriptor.batch_size_, h, w});
}
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
framework::Tensor *y,
framework::Tensor *out, bool trans_x,
bool trans_y) {
auto x_dim = x->dims();
auto y_dim = y->dims();
auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, false);
auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, false);
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename DeviceContext, typename T>
class BmmKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor &x = *context.Input<Tensor>("X");
const Tensor &y = *context.Input<Tensor>("Y");
Tensor *out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(x.dims(), 0, false);
auto mat_dim_b = math::CreateMatrixDescriptor(y.dims(), 0, false);
// auto scale = static_cast<T>(context.Attr<float>("alpha"));
blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
}
};
template <typename DeviceContext, typename T>
class BmmGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
const framework::Tensor &b, bool trans_b,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
}
void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
const framework::Tensor &b, bool trans_b,
framework::Tensor *out) const {
if (out == nullptr) return;
MatMul(context, a, trans_a, b, trans_b, out);
}
void Compute(const framework::ExecutionContext &context) const override {
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
CalcInputGrad(context, dout, false, y, true, dx);
CalcInputGrad(context, x, true, dout, false, dy);
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
}
};
} // namespace operators
} // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_
......@@ -53,7 +53,7 @@ from .tensor.creation import ones_like #DEFINE_ALIAS
# from .tensor.creation import range #DEFINE_ALIAS
from .tensor.creation import zeros #DEFINE_ALIAS
from .tensor.creation import zeros_like #DEFINE_ALIAS
# from .tensor.creation import arrange #DEFINE_ALIAS
from .tensor.creation import arange #DEFINE_ALIAS
# from .tensor.creation import eye #DEFINE_ALIAS
from .tensor.creation import full #DEFINE_ALIAS
# from .tensor.creation import linspace #DEFINE_ALIAS
......@@ -149,6 +149,7 @@ from .tensor.math import addmm #DEFINE_ALIAS
# from .tensor.io import save #DEFINE_ALIAS
# from .tensor.io import load #DEFINE_ALIAS
from .tensor.linalg import matmul #DEFINE_ALIAS
from .tensor.linalg import bmm #DEFINE_ALIAS
from .tensor.linalg import dot #DEFINE_ALIAS
# from .tensor.linalg import einsum #DEFINE_ALIAS
from .tensor.linalg import norm #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
from op_test import OpTest
class TestArangeOp(OpTest):
def setUp(self):
self.op_type = "range"
self.init_config()
self.inputs = {
'Start': np.array([self.case[0]]).astype(self.dtype),
'End': np.array([self.case[1]]).astype(self.dtype),
'Step': np.array([self.case[2]]).astype(self.dtype)
}
self.outputs = {
'Out': np.arange(self.case[0], self.case[1],
self.case[2]).astype(self.dtype)
}
def init_config(self):
self.dtype = np.float32
self.case = (0, 1, 0.2)
def test_check_output(self):
self.check_output()
class TestFloatArangeOpCase0(TestArangeOp):
def init_config(self):
self.dtype = np.float32
self.case = (0, 5, 1)
class TestInt32ArangeOpCase0(TestArangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (0, 5, 2)
class TestInt32ArangeOpCase1(TestArangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (10, 1, -2)
class TestInt32ArangeOpCase2(TestArangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (-1, -10, -2)
class TestArangeAPI(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
data = paddle.arange(0, 5, 1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result, = exe.run(fetch_list=[data])
expected_data = np.arange(0, 5, 1).astype(np.float32)
self.assertEqual((result == expected_data).all(), True)
with fluid.program_guard(fluid.Program()):
data = paddle.arange(0.0, 5.0, 1.0, 'int32')
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result, = exe.run(fetch_list=[data])
expected_data = np.arange(0, 5, 1).astype(np.int32)
self.assertEqual((result == expected_data).all(), True)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid import Program, program_guard
class TestBmmOp(OpTest):
def setUp(self):
self.op_type = "bmm"
X = np.random.random((10, 3, 4)).astype("float64")
Y = np.random.random((10, 4, 5)).astype("float64")
self.inputs = {'X': X, 'Y': Y}
Out = np.matmul(X, Y)
self.outputs = {'Out': Out}
def test_check_output(self):
self.check_output()
def test_checkout_grad(self):
self.check_grad(['X', 'Y'], 'Out')
class API_TestBmm(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data(
'data1', shape=[-1, 3, 4], dtype='float64')
data2 = fluid.layers.data(
'data2', shape=[-1, 4, 5], dtype='float64')
result_bmm = paddle.bmm(data1, data2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input1 = np.random.random([10, 3, 4]).astype('float64')
input2 = np.random.random([10, 4, 5]).astype('float64')
result, = exe.run(feed={"data1": input1,
"data2": input2},
fetch_list=[result_bmm])
expected_result = np.matmul(input1, input2)
self.assertTrue(np.allclose(expected_result, result))
class API_TestDygraphBmm(unittest.TestCase):
def test_out(self):
input1 = np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
[[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]])
input2 = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
[[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(input1)
y = fluid.dygraph.to_variable(input2)
out = paddle.bmm(x, y)
out_np = out.numpy()
expected_result = np.matmul(input1, input2)
self.assertTrue(np.allclose(expected_result, out_np))
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ from __future__ import print_function
#from .linalg import *
# TODO: define alias in tensor and framework directory
# from .creation import create_tensor #DEFINE_ALIAS
# from .creation import create_lod_tensor #DEFINE_ALIAS
# from .creation import create_random_int_lod #DEFINE_ALIAS
......@@ -32,7 +33,7 @@ from .creation import linspace #DEFINE_ALIAS
# from .creation import range #DEFINE_ALIAS
# from .creation import zeros #DEFINE_ALIAS
# from .creation import zeros_like #DEFINE_ALIAS
# from .creation import arrange #DEFINE_ALIAS
from .creation import arange #DEFINE_ALIAS
# from .creation import eye #DEFINE_ALIAS
from .creation import full # DEFINE_ALIAS
# from .creation import linspace #DEFINE_ALIAS
......@@ -136,6 +137,8 @@ from .linalg import dist #DEFINE_ALIAS
from .linalg import t #DEFINE_ALIAS
from .linalg import cross #DEFINE_ALIAS
# from .linalg import cholesky #DEFINE_ALIAS
# from .linalg import dot #DEFINE_ALIAS
from .linalg import bmm #DEFINE_ALIAS
# from .manipulation import cast #DEFINE_ALIAS
# from .manipulation import concat #DEFINE_ALIAS
# from .manipulation import expand #DEFINE_ALIAS
......@@ -152,13 +155,13 @@ from .linalg import cross #DEFINE_ALIAS
# from .manipulation import slice #DEFINE_ALIAS
# from .manipulation import split #DEFINE_ALIAS
# from .manipulation import squeeze #DEFINE_ALIAS
# from .manipulation import stack #DEFINE_ALIAS
# from .manipulation import stack #DEFINE_ALIAS
# from .manipulation import strided_slice #DEFINE_ALIAS
# from .manipulation import transpose #DEFINE_ALIAS
# from .manipulation import unique #DEFINE_ALIAS
# from .manipulation import unique_with_counts #DEFINE_ALIAS
# from .manipulation import unsqueeze #DEFINE_ALIAS
# from .manipulation import unstack #DEFINE_ALIAS
# from .manipulation import unstack #DEFINE_ALIAS
from .manipulation import flip #DEFINE_ALIAS
# from .manipulation import unbind #DEFINE_ALIAS
from .manipulation import roll #DEFINE_ALIAS
......
......@@ -24,7 +24,7 @@ from paddle.common_ops_import import *
# TODO: define functions to get create a tensor
__all__ = [
'create_tensor',
# 'create_tensor',
# 'create_lod_tensor',
# 'create_random_int_lodtensor',
# 'crop_tensor',
......@@ -37,7 +37,7 @@ __all__ = [
# 'range',
'zeros',
'zeros_like',
# 'arrange',
'arange',
'eye',
'full',
'full_like',
......@@ -536,7 +536,75 @@ def full(shape,
with device_guard(device):
out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out)
return out
def arange(start, end, step=1, dtype=None, name=None):
"""
Return evenly spaced values within a given interval.
Values are generated within the half-open interval [start, stop) (in other words,
the interval including start but excluding stop).
Parameters:
start(float32 | float64 | int32 | int64 | Variable): Start of interval. The interval includes this value.
when start is Variable, it is a 1-D Tensor with shape [1].
end(float32 | float64 | int32 | int64 | Variable): End of interval. The interval does not include this
value, except in some cases where step is not an integer
and floating point round-off affects the length of out. When end is Variable,
it is a 1-D Tensor with shape [1].
step(float32 | float64 | int32 | int64 | Variable): Spacing between values. For any output out, this is the
distance between two adjacent values, out[i+1] - out[i].
dtype(str|core.VarDesc.VarType): the data type of the output tensor, can be float32, float64, int32, int64.
Returns: a 1-D Tensor which is evenly spaced values within a given interval. Its data type is set by dtype.
Return type: Variable
examples:
.. code-block:: python
import paddle
# expected out put: [0, 2, 4, 6, 8]
data = paddle.arange(0, 10, 2, 'int32')
#dygraph mode
import paddle
import paddle.fluid as fluid
with fluid.dygraph.guard():
x = paddle.arange(0, 6, 2)
# x: [0, 2, 4]
# x dtype: float32
"""
helper = LayerHelper("range", **locals())
if dtype is None:
dtype = 'float32'
check_dtype(dtype, 'create data type',
['float32', 'float64', 'int32', 'int64'], 'range')
dtype = convert_dtype(dtype)
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
if not isinstance(end, Variable):
end = fill_constant([1], dtype, end)
if not isinstance(step, Variable):
step = fill_constant([1], dtype, step)
out = helper.create_variable_for_type_inference(dtype=start.dtype)
helper.append_op(
type='range',
inputs={'Start': start,
'End': end,
'Step': step},
outputs={'Out': [out]})
out.stop_gradient = True
return out
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.common_ops_import import *
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type
......@@ -26,7 +27,8 @@ __all__ = [
't',
'cross',
# 'cholesky',
# 'tensordot'
# 'tensordot',
'bmm'
]
......@@ -596,3 +598,50 @@ def cross(input, other, dim=None):
outputs={'Out': out},
attrs=attrs)
return out
def bmm(x, y, name=None):
"""
Applies batched matrix multiplication to two tensors.
Both of the two input tensors must be three-dementional and share the same batch size.
if x is a (b, m, k) tensor, y is a (b, k, n) tensor, the output will be a (b, m, n) tensor.
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a Tensor or LoDTensor.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The product Tensor (or LoDTensor) variable.
Examples:
import paddle
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[10, 3, 4], dtype='float32')
y = fluid.layers.data(name='y', shape=[10, 4, 5], dtype='float32')
out = paddle.bmm(x, y)
# In dygraph mode:
# size input1: (2, 2, 3) and input2: (2, 3, 2)
input1 = np.array([[[1.0, 1.0, 1.0],[2.0, 2.0, 2.0]],[[3.0, 3.0, 3.0],[4.0, 4.0, 4.0]]])
input2 = np.array([[[1.0, 1.0],[2.0, 2.0],[3.0, 3.0]],[[4.0, 4.0],[5.0, 5.0],[6.0, 6.0]]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(input1)
y = fluid.dygraph.to_variable(input2)
out = paddle.bmm(x, y)
#output size: (2, 2, 2)
#output value:
#[[[6.0, 6.0],[12.0, 12.0]],[[45.0, 45.0],[60.0, 60.0]]]
out_np = out.numpy()
"""
helper = LayerHelper('bmm', **locals())
if in_dygraph_mode():
return core.ops.bmm(x, y)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='bmm', inputs={'X': x, 'Y': y}, outputs={'Out': out})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册