diff --git a/paddle/fluid/operators/bmm_op.cc b/paddle/fluid/operators/bmm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..deba52b52e61cb490a911574f02042a02af2270c --- /dev/null +++ b/paddle/fluid/operators/bmm_op.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/bmm_op.h" +#include + +namespace paddle { +namespace operators { + +class BmmOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of BmmOp should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of BmmOp should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of BmmOp should not be null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 3, + platform::errors::InvalidArgument( + "Input(X) of BmmOp must be 3-dimensional in BmmOp, " + "but received X's shape: [%s].", + x_dims)); + PADDLE_ENFORCE_EQ(y_dims.size(), 3, + platform::errors::InvalidArgument( + "Input(Y) of BmmOp must be 3-dimensional in BmmOp, " + "but received Y's shape: [%s].", + y_dims)); + PADDLE_ENFORCE_EQ( + x_dims[0], y_dims[0], + platform::errors::InvalidArgument( + "Input(X) and Input(Y) must have the same batch size in BmmOp, " + "but received X's batch size: [%s]," + "Y's batch size [%s]", + x_dims[0], y_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[2], y_dims[1], + platform::errors::InvalidArgument( + "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," + "but receive X's width: [%s]," + "Y's height: [%s].", + x_dims[2], y_dims[1])); + + std::vector dim_out; + dim_out.push_back(x_dims[0]); + dim_out.push_back(x_dims[1]); + dim_out.push_back(y_dims[2]); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class BmmOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The first input tensor of Bmm op."); + AddInput("Y", "(Tensor), The second input tensor of Bmm op."); + AddOutput("Out", "(Tensor), The output tensor of Bmm op."); + AddComment(R"DOC( +The Bmm operator is used to perform batched matrix multiplication +over the last two dimensions of the input tensors `X` and `Y` +which are both 3-dimentionsal. + +Examples: +- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] + + )DOC"); + } +}; + +class BmmOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of BmmOp should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of BmmOp should not be null")); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::NotFound( + "Output(Out@GRAD) of BmmOp should not be null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class BmmOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("bmm_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(bmm, ops::BmmOp, ops::BmmOpMaker, + ops::BmmOpGradMaker, + ops::BmmOpGradMaker); +REGISTER_OPERATOR(bmm_grad, ops::BmmOpGrad); +REGISTER_OP_CPU_KERNEL( + bmm, ops::BmmKernel, + ops::BmmKernel); +REGISTER_OP_CPU_KERNEL( + bmm_grad, ops::BmmGradKernel, + ops::BmmGradKernel); diff --git a/paddle/fluid/operators/bmm_op.cu b/paddle/fluid/operators/bmm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..961d74b7ad42ad6fac23e436d64687a2217ee47c --- /dev/null +++ b/paddle/fluid/operators/bmm_op.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/bmm_op.h" + +#ifdef PADDLE_WITH_CUDA +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + bmm, ops::BmmKernel, + ops::BmmKernel, + ops::BmmKernel); + +REGISTER_OP_CUDA_KERNEL( + bmm_grad, ops::BmmGradKernel, + ops::BmmGradKernel, + ops::BmmGradKernel); +#endif diff --git a/paddle/fluid/operators/bmm_op.h b/paddle/fluid/operators/bmm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..49104d4f08d288dd844545c0a8256fd22862ccad --- /dev/null +++ b/paddle/fluid/operators/bmm_op.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#ifndef PADDLE_FLUID_OPERATORS_BMM_OP_H_ +#define PADDLE_FLUID_OPERATORS_BMM_OP_H_ + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +static void ReshapeTensorIntoMatrixSequence( + framework::Tensor *x, const math::MatDescriptor &descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + + x->Resize({descriptor.batch_size_, h, w}); +} + +static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, + framework::Tensor *y, + framework::Tensor *out, bool trans_x, + bool trans_y) { + auto x_dim = x->dims(); + auto y_dim = y->dims(); + auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, false); + auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, false); + + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +template +class BmmKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor &x = *context.Input("X"); + const Tensor &y = *context.Input("Y"); + Tensor *out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + + auto mat_dim_a = math::CreateMatrixDescriptor(x.dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(y.dims(), 0, false); + + // auto scale = static_cast(context.Attr("alpha")); + blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); + } +}; + +template +class BmmGradKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + const framework::Tensor &b, bool trans_b, + framework::Tensor *out) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + + blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); + } + void CalcInputGrad(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + const framework::Tensor &b, bool trans_b, + framework::Tensor *out) const { + if (out == nullptr) return; + MatMul(context, a, trans_a, b, trans_b, out); + } + void Compute(const framework::ExecutionContext &context) const override { + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = + *context.Input(framework::GradVarName("Out")); + auto *dx = context.Output(framework::GradVarName("X")); + auto *dy = context.Output(framework::GradVarName("Y")); + + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, false, false); + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + CalcInputGrad(context, dout, false, y, true, dx); + CalcInputGrad(context, x, true, dout, false, dy); + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } + } + } +}; + +} // namespace operators +} // namespace paddle +#endif // PADDLE_FLUID_OPERATORS_BMM_OP_H_ diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 05bebd759d39379f23071d6315234c7ce44e6a8e..bbe085ede696bd59289bfe99566f38cc4e891145 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -53,7 +53,7 @@ from .tensor.creation import ones_like #DEFINE_ALIAS # from .tensor.creation import range #DEFINE_ALIAS from .tensor.creation import zeros #DEFINE_ALIAS from .tensor.creation import zeros_like #DEFINE_ALIAS -# from .tensor.creation import arrange #DEFINE_ALIAS +from .tensor.creation import arange #DEFINE_ALIAS # from .tensor.creation import eye #DEFINE_ALIAS from .tensor.creation import full #DEFINE_ALIAS # from .tensor.creation import linspace #DEFINE_ALIAS @@ -149,6 +149,7 @@ from .tensor.math import addmm #DEFINE_ALIAS # from .tensor.io import save #DEFINE_ALIAS # from .tensor.io import load #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS +from .tensor.linalg import bmm #DEFINE_ALIAS from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS from .tensor.linalg import norm #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_arange.py b/python/paddle/fluid/tests/unittests/test_arange.py new file mode 100644 index 0000000000000000000000000000000000000000..d715744b02a010e442dfe9fb4f2409d481c8f8d9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_arange.py @@ -0,0 +1,91 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +from op_test import OpTest + + +class TestArangeOp(OpTest): + def setUp(self): + self.op_type = "range" + self.init_config() + self.inputs = { + 'Start': np.array([self.case[0]]).astype(self.dtype), + 'End': np.array([self.case[1]]).astype(self.dtype), + 'Step': np.array([self.case[2]]).astype(self.dtype) + } + + self.outputs = { + 'Out': np.arange(self.case[0], self.case[1], + self.case[2]).astype(self.dtype) + } + + def init_config(self): + self.dtype = np.float32 + self.case = (0, 1, 0.2) + + def test_check_output(self): + self.check_output() + + +class TestFloatArangeOpCase0(TestArangeOp): + def init_config(self): + self.dtype = np.float32 + self.case = (0, 5, 1) + + +class TestInt32ArangeOpCase0(TestArangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (0, 5, 2) + + +class TestInt32ArangeOpCase1(TestArangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (10, 1, -2) + + +class TestInt32ArangeOpCase2(TestArangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (-1, -10, -2) + + +class TestArangeAPI(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program()): + data = paddle.arange(0, 5, 1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result, = exe.run(fetch_list=[data]) + expected_data = np.arange(0, 5, 1).astype(np.float32) + self.assertEqual((result == expected_data).all(), True) + + with fluid.program_guard(fluid.Program()): + data = paddle.arange(0.0, 5.0, 1.0, 'int32') + place = fluid.CPUPlace() + exe = fluid.Executor(place) + result, = exe.run(fetch_list=[data]) + expected_data = np.arange(0, 5, 1).astype(np.int32) + self.assertEqual((result == expected_data).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..993ac25d8d4b638a56c9e2aa4f832f576f0b2ae7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.tensor as tensor +from paddle.fluid import Program, program_guard + + +class TestBmmOp(OpTest): + def setUp(self): + self.op_type = "bmm" + X = np.random.random((10, 3, 4)).astype("float64") + Y = np.random.random((10, 4, 5)).astype("float64") + self.inputs = {'X': X, 'Y': Y} + Out = np.matmul(X, Y) + self.outputs = {'Out': Out} + + def test_check_output(self): + self.check_output() + + def test_checkout_grad(self): + self.check_grad(['X', 'Y'], 'Out') + + +class API_TestBmm(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data( + 'data1', shape=[-1, 3, 4], dtype='float64') + data2 = fluid.layers.data( + 'data2', shape=[-1, 4, 5], dtype='float64') + result_bmm = paddle.bmm(data1, data2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([10, 3, 4]).astype('float64') + input2 = np.random.random([10, 4, 5]).astype('float64') + result, = exe.run(feed={"data1": input1, + "data2": input2}, + fetch_list=[result_bmm]) + expected_result = np.matmul(input1, input2) + self.assertTrue(np.allclose(expected_result, result)) + + +class API_TestDygraphBmm(unittest.TestCase): + def test_out(self): + input1 = np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], + [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]]) + input2 = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], + [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]]) + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(input1) + y = fluid.dygraph.to_variable(input2) + out = paddle.bmm(x, y) + out_np = out.numpy() + expected_result = np.matmul(input1, input2) + self.assertTrue(np.allclose(expected_result, out_np)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f93fa7f18603ab85ed7ab9ae94671a0d69d5fee5..d2ecadd7b1362457e2ce9449d88ff60559596b5d 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -18,6 +18,7 @@ from __future__ import print_function #from .linalg import * # TODO: define alias in tensor and framework directory + # from .creation import create_tensor #DEFINE_ALIAS # from .creation import create_lod_tensor #DEFINE_ALIAS # from .creation import create_random_int_lod #DEFINE_ALIAS @@ -32,7 +33,7 @@ from .creation import linspace #DEFINE_ALIAS # from .creation import range #DEFINE_ALIAS # from .creation import zeros #DEFINE_ALIAS # from .creation import zeros_like #DEFINE_ALIAS -# from .creation import arrange #DEFINE_ALIAS +from .creation import arange #DEFINE_ALIAS # from .creation import eye #DEFINE_ALIAS from .creation import full # DEFINE_ALIAS # from .creation import linspace #DEFINE_ALIAS @@ -136,6 +137,8 @@ from .linalg import dist #DEFINE_ALIAS from .linalg import t #DEFINE_ALIAS from .linalg import cross #DEFINE_ALIAS # from .linalg import cholesky #DEFINE_ALIAS +# from .linalg import dot #DEFINE_ALIAS +from .linalg import bmm #DEFINE_ALIAS # from .manipulation import cast #DEFINE_ALIAS # from .manipulation import concat #DEFINE_ALIAS # from .manipulation import expand #DEFINE_ALIAS @@ -152,13 +155,13 @@ from .linalg import cross #DEFINE_ALIAS # from .manipulation import slice #DEFINE_ALIAS # from .manipulation import split #DEFINE_ALIAS # from .manipulation import squeeze #DEFINE_ALIAS -# from .manipulation import stack #DEFINE_ALIAS +# from .manipulation import stack #DEFINE_ALIAS # from .manipulation import strided_slice #DEFINE_ALIAS # from .manipulation import transpose #DEFINE_ALIAS # from .manipulation import unique #DEFINE_ALIAS # from .manipulation import unique_with_counts #DEFINE_ALIAS # from .manipulation import unsqueeze #DEFINE_ALIAS -# from .manipulation import unstack #DEFINE_ALIAS +# from .manipulation import unstack #DEFINE_ALIAS from .manipulation import flip #DEFINE_ALIAS # from .manipulation import unbind #DEFINE_ALIAS from .manipulation import roll #DEFINE_ALIAS diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index be58a9dd868d3bbc1f23c0e76996a951bb80ce18..64c82accd729adce52963397fd22091496d379b5 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -24,7 +24,7 @@ from paddle.common_ops_import import * # TODO: define functions to get create a tensor __all__ = [ - 'create_tensor', + # 'create_tensor', # 'create_lod_tensor', # 'create_random_int_lodtensor', # 'crop_tensor', @@ -37,7 +37,7 @@ __all__ = [ # 'range', 'zeros', 'zeros_like', - # 'arrange', + 'arange', 'eye', 'full', 'full_like', @@ -536,7 +536,75 @@ def full(shape, with device_guard(device): out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out) + return out + + +def arange(start, end, step=1, dtype=None, name=None): + """ + Return evenly spaced values within a given interval. + + Values are generated within the half-open interval [start, stop) (in other words, + the interval including start but excluding stop). + + Parameters: + start(float32 | float64 | int32 | int64 | Variable): Start of interval. The interval includes this value. + when start is Variable, it is a 1-D Tensor with shape [1]. + end(float32 | float64 | int32 | int64 | Variable): End of interval. The interval does not include this + value, except in some cases where step is not an integer + and floating point round-off affects the length of out. When end is Variable, + it is a 1-D Tensor with shape [1]. + step(float32 | float64 | int32 | int64 | Variable): Spacing between values. For any output out, this is the + distance between two adjacent values, out[i+1] - out[i]. + dtype(str|core.VarDesc.VarType): the data type of the output tensor, can be float32, float64, int32, int64. + + Returns: a 1-D Tensor which is evenly spaced values within a given interval. Its data type is set by dtype. + + Return type: Variable + + examples: + + .. code-block:: python + + import paddle + # expected out put: [0, 2, 4, 6, 8] + data = paddle.arange(0, 10, 2, 'int32') + #dygraph mode + import paddle + import paddle.fluid as fluid + with fluid.dygraph.guard(): + x = paddle.arange(0, 6, 2) + # x: [0, 2, 4] + # x dtype: float32 + + """ + helper = LayerHelper("range", **locals()) + + if dtype is None: + dtype = 'float32' + + check_dtype(dtype, 'create data type', + ['float32', 'float64', 'int32', 'int64'], 'range') + + dtype = convert_dtype(dtype) + if not isinstance(start, Variable): + start = fill_constant([1], dtype, start) + + if not isinstance(end, Variable): + end = fill_constant([1], dtype, end) + + if not isinstance(step, Variable): + step = fill_constant([1], dtype, step) + + out = helper.create_variable_for_type_inference(dtype=start.dtype) + + helper.append_op( + type='range', + inputs={'Start': start, + 'End': end, + 'Step': step}, + outputs={'Out': [out]}) + out.stop_gradient = True return out diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c3ccb43e90f5cf0e3800a00c1216aa361f1affb2..70624b63b9fe9fa07dfea01fd21c8f74350f9198 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from paddle.common_ops_import import * from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type @@ -26,7 +27,8 @@ __all__ = [ 't', 'cross', # 'cholesky', - # 'tensordot' + # 'tensordot', + 'bmm' ] @@ -596,3 +598,50 @@ def cross(input, other, dim=None): outputs={'Out': out}, attrs=attrs) return out + + +def bmm(x, y, name=None): + """ + Applies batched matrix multiplication to two tensors. + + Both of the two input tensors must be three-dementional and share the same batch size. + + if x is a (b, m, k) tensor, y is a (b, k, n) tensor, the output will be a (b, m, n) tensor. + + Args: + x (Variable): The input variable which is a Tensor or LoDTensor. + y (Variable): The input variable which is a Tensor or LoDTensor. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The product Tensor (or LoDTensor) variable. + + Examples: + import paddle + import paddle.fluid as fluid + x = fluid.layers.data(name='x', shape=[10, 3, 4], dtype='float32') + y = fluid.layers.data(name='y', shape=[10, 4, 5], dtype='float32') + out = paddle.bmm(x, y) + + # In dygraph mode: + # size input1: (2, 2, 3) and input2: (2, 3, 2) + input1 = np.array([[[1.0, 1.0, 1.0],[2.0, 2.0, 2.0]],[[3.0, 3.0, 3.0],[4.0, 4.0, 4.0]]]) + input2 = np.array([[[1.0, 1.0],[2.0, 2.0],[3.0, 3.0]],[[4.0, 4.0],[5.0, 5.0],[6.0, 6.0]]]) + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(input1) + y = fluid.dygraph.to_variable(input2) + out = paddle.bmm(x, y) + #output size: (2, 2, 2) + #output value: + #[[[6.0, 6.0],[12.0, 12.0]],[[45.0, 45.0],[60.0, 60.0]]] + out_np = out.numpy() + """ + + helper = LayerHelper('bmm', **locals()) + if in_dygraph_mode(): + return core.ops.bmm(x, y) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type='bmm', inputs={'X': x, 'Y': y}, outputs={'Out': out}) + return out