diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f6e6856c61588f08f58a4e92e14e2a78f63745e5 --- /dev/null +++ b/paddle/fluid/operators/addmm_op.cc @@ -0,0 +1,236 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/addmm_op.h" +#include +#include +#include +#include +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class AddMMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + platform::errors::NotFound( + "Input(Input) of AddMMOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of AddMMOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of AddMMOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::NotFound( + "Output(Out) of AddMMOp should not be null.")); + + auto input_dims = ctx->GetInputDim("Input"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto ndim_input = input_dims.size(); + auto ndim_x = x_dims.size(); + auto ndim_y = y_dims.size(); + + float alpha = ctx->Attrs().Get("Alpha"); + float beta = ctx->Attrs().Get("Beta"); + + VLOG(3) << "addmm operator input.shape=" << input_dims + << " x.shape=" << x_dims << " y.shape=" << y_dims + << " beta=" << beta << " alpha=" << alpha + << " ndim_input=" << ndim_input << " ndim_x=" << ndim_x + << " ndim_y=" << ndim_y; + + PADDLE_ENFORCE_NE(framework::product(input_dims), 0, + platform::errors::PreconditionNotMet( + "The Input variable Input(%s) has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + ctx->Inputs("Input").front())); + + PADDLE_ENFORCE_NE(framework::product(x_dims), 0, + platform::errors::PreconditionNotMet( + "The Input variable X(%s) has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + ctx->Inputs("X").front())); + + PADDLE_ENFORCE_NE(framework::product(y_dims), 0, + platform::errors::PreconditionNotMet( + "The Input variable Y(%s) has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + ctx->Inputs("Y").front())); + // dim check + PADDLE_ENFORCE_EQ(ndim_input, 2, + platform::errors::InvalidArgument( + "The input tensor input's dimension must be 2. 
" + "But received input's dimension = [%s].", + ndim_input)); + PADDLE_ENFORCE_EQ(ndim_x, 2, + platform::errors::InvalidArgument( + "The input tensor x's dimension must be 2. " + "But received x's dimension = [%s].", + ndim_x)); + PADDLE_ENFORCE_EQ(ndim_y, 2, + platform::errors::InvalidArgument( + "The input tensor y's dimension must be 2. " + "But received y's dimension = [%s].", + ndim_y)); + + std::vector output_dims; + output_dims.push_back(x_dims[0]); + output_dims.push_back(y_dims[1]); + + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + ctx->ShareLoD("Input", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + + if (input_data_type == framework::DataTypeTrait::DataType() || + input_data_type == framework::DataTypeTrait::DataType()) { + customized_type_value = kMULMKLDNNINT8; + } + } +#endif + + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library, customized_type_value); + } +}; + +class AddMMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "(Tensor), tensor to be added to the final result."); + AddInput("X", "(Tensor), The first input tensor for mul."); + AddInput("Y", "(Tensor), The second input tensor for mul."); + AddOutput("Out", "(Tensor), The output tensor of addmm op."); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("Alpha", "coefficient of x*y.").SetDefault(1.0f); + AddAttr("Beta", "coefficient of input.").SetDefault(1.0f); + AddComment(R"DOC( +AddMM Operator. +This operator is used to perform matrix multiplication for input $x$ and $y$ with coefficient $alpha$. +$input$ with coefficient $beta$ is added to the final result. +The equation is: + +$$Out = alpha * x * y + beta * input$$ + +$x$ and $y$ must be two-dimensional, and $input$ can be broadcastable. 
+)DOC"); + } +}; + +class AddMMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Input"), true, + platform::errors::NotFound("Input(Input) should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::NotFound("Input(Out@GRAD) should not be null")); + const auto& input_dims = ctx->GetInputDim("Input"); + const auto& x_dims = ctx->GetInputDim("X"); + const auto& y_dims = ctx->GetInputDim("Y"); + + auto input_grad_name = framework::GradVarName("Input"); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(input_grad_name)) { + ctx->SetOutputDim(input_grad_name, input_dims); + } + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +template +class AddMMOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("addmm_grad"); + retv->SetInput("Input", this->Input("Input")); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker, + ops::AddMMOpGradMaker, + ops::AddMMOpGradMaker); + +REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp); + +REGISTER_OP_CPU_KERNEL( + addmm, ops::AddMMKernel, + ops::AddMMKernel); + +REGISTER_OP_CPU_KERNEL( + addmm_grad, ops::AddMMGradKernel, + ops::AddMMGradKernel); diff --git a/paddle/fluid/operators/addmm_op.cu b/paddle/fluid/operators/addmm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e42d9c84f9234a756362acd67029b2ace4f6c9fb --- /dev/null +++ b/paddle/fluid/operators/addmm_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/addmm_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(addmm, ops::AddMMKernel, + ops::AddMMKernel); +REGISTER_OP_CUDA_KERNEL(addmm_grad, + ops::AddMMGradKernel, + ops::AddMMGradKernel); diff --git a/paddle/fluid/operators/addmm_op.h b/paddle/fluid/operators/addmm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..97e3ed9c1adda0082a7ce74cfb5d9b6cb78dde63 --- /dev/null +++ b/paddle/fluid/operators/addmm_op.h @@ -0,0 +1,193 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenTensor = framework::EigenTensor; + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; + +using Tensor = framework::Tensor; + +constexpr int kMULMKLDNNINT8 = 1; + +template +class AddMMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* x = context.Input("X"); + const Tensor* y = context.Input("Y"); + + auto input_dims = input->dims(); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + // broadcast mode check + if (x_dims[0] != input_dims[0]) { + PADDLE_ENFORCE_EQ(input_dims[0], 1, + platform::errors::InvalidArgument( + "When x_dims[0] is not equal with input_dims[0], " + "input_dims[0] must be 1 but got %s", + input_dims[0])); + PADDLE_ENFORCE_EQ( + y_dims[1] == input_dims[1] || input_dims[1] == 1, true, + platform::errors::InvalidArgument( + "The input tensor shape mismatch, input shape=[%s], " + "x shape=[%s], y shape=[%s]", + input_dims, x_dims, y_dims)); + } + // broadcast mode check + if (y_dims[1] != input_dims[1]) { + PADDLE_ENFORCE_EQ(input_dims[1], 1, + platform::errors::InvalidArgument( + "When y_dims[1] is not equal with input_dims[0], " + "input_dims[0] must be 1 but got %s", + input_dims[1])); + PADDLE_ENFORCE_EQ( + x_dims[0] == input_dims[0] || input_dims[0] == 1, true, + platform::errors::InvalidArgument( + "The input tensor shape mismatch, input shape=[%s], " + "x shape=[%s], y shape=[%s]", + input_dims, x_dims, y_dims)); + } + // broadcast mode check + PADDLE_ENFORCE_EQ( + x_dims[1], y_dims[0], + platform::errors::InvalidArgument( + "The input tensor X's width must be equal with matrix Y' height. 
" + "But received X's shape = [%s], Y's shape = [%s].", + x_dims[1], y_dims[0])); + + auto* out = context.Output("Out"); + out->mutable_data({x_dims[0], y_dims[1]}, context.GetPlace()); + + float alpha = context.template Attr("Alpha"); + float beta = context.template Attr("Beta"); + + auto blas = math::GetBlas(context); + + // calc broadcast dim + Array2 bcast_dims; + bcast_dims[0] = x_dims[0] / input_dims[0]; + bcast_dims[1] = y_dims[1] / input_dims[1]; + VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]"; + // broadcast using eigen + auto eigen_input = EigenTensor::From(*input); + auto eigen_out = EigenTensor::From(*out); + auto& place = + *context.template device_context().eigen_device(); + eigen_out.device(place) = eigen_input.broadcast(bcast_dims); + + blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha, + x->data(), x_dims[1], y->data(), y_dims[1], beta, + out->data(), y_dims[1]); + } +}; + +template +class AddMMGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto in_dims = ctx.Input("Input")->dims(); + auto* dinput = + ctx.Output(framework::GradVarName("Input")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + float alpha = ctx.Attr("Alpha"); + float beta = ctx.Attr("Beta"); + + int total_elems = 0; + + VLOG(3) << "alpha: " << alpha << " beta: " << beta; + + if (dinput != nullptr) { + dinput->set_lod(dout->lod()); + } + if (dx != nullptr) { + dx->set_lod(x->lod()); + } + if (dy != nullptr) { + dy->set_lod(y->lod()); + } + + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + if (dinput) { + dinput->mutable_data(ctx.GetPlace()); + total_elems = in_dims[0] * in_dims[1]; + auto& place = + *ctx.template device_context().eigen_device(); + auto eigen_dout = EigenTensor::From(*dout); + auto eigen_dinput = EigenTensor::From(*dinput); + + bool row_compress = in_dims[0] != dout->dims()[0]; + bool col_compress = in_dims[1] != dout->dims()[1]; + auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]); + + if (row_compress && col_compress) { + eigen_dinput.device(place) = + eigen_dout.sum().eval().reshape(eigen_dinput_shape); + } else if (row_compress) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape); + } else if (col_compress) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape); + } else { + blas.VCOPY(total_elems, dout->data(), dinput->data()); + } + + blas.SCAL(total_elems, beta, dinput->data()); + } + if (dx) { + dx->mutable_data(ctx.GetPlace()); + total_elems = x->dims()[0] * x->dims()[1]; + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N + blas.MatMul(*dout, false, *y, true, dx); + blas.SCAL(total_elems, alpha, dx->data()); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + total_elems = x->dims()[1] * y->dims()[1]; + // dy = x' * dout. 
dy K x N, dout : M x N, x : M x K + blas.MatMul(*x, true, *dout, false, dy); + blas.SCAL(total_elems, alpha, dy->data()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 4188e26fc9830e63381c040d17670931045b2630..c0ab35b0e753c6ca9357ed2a92d0e493167e5cca 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -39,6 +39,20 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...)); } + template + static void SCAL(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSscal(args...), + platform::errors::External("dynload cublasSscal lib failed")); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasScopy(args...), + platform::errors::External("dynload cublasScopy lib failed")); + } + template static void GEMV(ARGS... args) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...)); @@ -92,6 +106,20 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...)); } + template + static void SCAL(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDscal(args...), + platform::errors::External("dynload cublasDscal lib failed")); + } + + template + static void VCOPY(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDcopy(args...), + platform::errors::External("dynload cublasDcopy lib failed")); + } + template static void GEMV(ARGS... args) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...)); @@ -318,6 +346,20 @@ void Blas::AXPY(int n, T alpha, const T *x, }); } +template <> +template +void Blas::SCAL(int n, const T alpha, T *x) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::SCAL(handle, n, &alpha, x, 1); }); +} + +template <> +template +void Blas::VCOPY(int n, const T *x, T *y) const { + context_.CublasCall( + [&](cublasHandle_t handle) { CUBlas::VCOPY(handle, n, x, 1, y, 1); }); +} + template <> template void Blas::GEMV(bool trans_a, int M, int N, diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index ed9b9133c6a0d7597d73a7090c41b4dc56062e24..439a51dd69588d72a8c8febd6e403cb3ea2b00fd 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -64,6 +64,10 @@ extern void *cublas_dso_handle; #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSaxpy_v2); \ __macro(cublasDaxpy_v2); \ + __macro(cublasSscal_v2); \ + __macro(cublasDscal_v2); \ + __macro(cublasScopy_v2); \ + __macro(cublasDcopy_v2); \ __macro(cublasSgemv_v2); \ __macro(cublasDgemv_v2); \ __macro(cublasSgemm_v2); \ diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 26e2316e514a4fa305e24ebfae86fd8d4e3ba719..72609882d713414d910b42a15c3706383dc7dd4d 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -142,7 +142,7 @@ from .tensor.math import add #DEFINE_ALIAS # from .tensor.math import log1p #DEFINE_ALIAS # from .tensor.math import erf #DEFINE_ALIAS # from .tensor.math import addcmul #DEFINE_ALIAS -# from .tensor.math import addmm #DEFINE_ALIAS +from .tensor.math import addmm #DEFINE_ALIAS # from .tensor.attribute import rank #DEFINE_ALIAS # from .tensor.attribute import shape #DEFINE_ALIAS # from .tensor.io import save #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_addmm_op.py 
b/python/paddle/fluid/tests/unittests/test_addmm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb44f37225ec7c3ed510aad38593a4bccfe629b8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_addmm_op.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from op_test import OpTest
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+
+class TestAddMMOp(OpTest):
+    # test basic
+    def setUp(self):
+        self.op_type = "addmm"
+        self.dtype = np.float64
+        self.init_dtype_type()
+        self.inputs = {
+            'Input': np.random.random((100, 1)).astype(self.dtype),
+            'X': np.random.random((100, 10)).astype(self.dtype),
+            'Y': np.random.random((10, 20)).astype(self.dtype),
+        }
+        self.outputs = {
+            'Out':
+            self.inputs['Input'] + np.dot(self.inputs['X'], self.inputs['Y'])
+        }
+
+    def init_dtype_type(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['Input', 'X', 'Y'], 'Out')
+
+    def test_check_grad_x(self):
+        self.check_grad(['X'], 'Out', no_grad_set=None)
+
+    def test_check_grad_y(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=None)
+
+    def test_check_grad_input(self):
+        self.check_grad(['Input'], 'Out', no_grad_set=None)
+
+
+class TestAddMMOpError(unittest.TestCase):
+    # test error
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input type of addmm_op must be Variable.
+            input = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            x2 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, paddle.addmm, input, x1, x2)
+            # The input dtype of addmm_op must be float32 or float64.
+            input = fluid.layers.data(name='input', shape=[4], dtype="int32")
+            x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
+            x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
+            self.assertRaises(TypeError, paddle.addmm, input, x3, x4)
+
+class TestAddMMOp2(TestAddMMOp):
+    # test alpha and beta
+    def setUp(self):
+        self.op_type = "addmm"
+        self.dtype = np.float64
+        self.init_dtype_type()
+        self.inputs = {
+            'Input': np.random.random((20, 30)).astype(self.dtype),
+            'X': np.random.random((20, 6)).astype(self.dtype),
+            'Y': np.random.random((6, 30)).astype(self.dtype),
+        }
+        self.attrs = {
+            'Alpha': 0.1,
+            'Beta': 1.0,
+        }
+        self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
+            self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestAddMMOp3(OpTest):
+    # test broadcast
+    def setUp(self):
+        self.op_type = "addmm"
+        self.dtype = np.float64
+        self.init_dtype_type()
+        self.inputs = {
+            'Input': np.random.random((1, 100)).astype(self.dtype),
+            'X': np.random.random((20, 10)).astype(self.dtype),
+            'Y': np.random.random((10, 100)).astype(self.dtype),
+        }
+        self.attrs = {
+            'Alpha': 0.5,
+            'Beta': 2.0,
+        }
+        self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
+            self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}
+
+    def init_dtype_type(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['Input', 'X', 'Y'], 'Out')
+
+    def test_check_grad_x(self):
+        self.check_grad(['X'], 'Out', no_grad_set=None)
+
+    def test_check_grad_y(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=None)
+
+    def test_check_grad_input(self):
+        self.check_grad(['Input'], 'Out', no_grad_set=None)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 6982202be1b71032a2353df7e45a6bcbab72be6e..8059720312b54319f09726cc5219fd3f7e46e491 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -3464,6 +3464,28 @@ class TestBook(LayerTest):
             x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25)
         return (out)
 
+    def test_addmm(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            input = layers.data(
+                name='input_data',
+                shape=[3, 3],
+                append_batch_size=False,
+                dtype='float32')
+            x = layers.data(
+                name='x',
+                shape=[3, 2],
+                append_batch_size=False,
+                dtype='float32')
+            y = layers.data(
+                name='y',
+                shape=[2, 3],
+                append_batch_size=False,
+                dtype='float32')
+
+            out = paddle.addmm(input=input, x=x, y=y)
+            return (out)
+
     def test_retinanet_detection_output(self):
         with program_guard(fluid.default_main_program(),
                            fluid.default_startup_program()):
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 64b4e3078915a5c2342bba17ef96a0b57ecf77af..fb4296f0013390b8cf03b636c1d032bf07ec8f91 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -117,7 +117,7 @@ from .math import add  #DEFINE_ALIAS
 # from .math import log1p  #DEFINE_ALIAS
 # from .math import erf  #DEFINE_ALIAS
 # from .math import addcmul  #DEFINE_ALIAS
-# from .math import addmm  #DEFINE_ALIAS
+from .math import addmm  #DEFINE_ALIAS
 # from .attribute import rank  #DEFINE_ALIAS
 # from .attribute import shape  #DEFINE_ALIAS
 # from .io import save  #DEFINE_ALIAS
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 69f984f56561cd4e836f67a6693040cb6db3ef06..fc350bc7817e550436fb246aab1eaa9e9821b9b2 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -75,7 +75,7 @@ __all__ = [
 #        'log1p',
 #        'erf',
 #        'addcmul',
-#        'addmm']
+        'addmm' ]
 # yapf: enable.
 
 
@@ -648,6 +648,7 @@ for func in [
         skip_attrs_set={"x_data_format", "y_data_format", "axis"
                         }) + """\n""" + str(func.__doc__)
 
+
 def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
     """
     Computes the sum of tensor elements over the given dimension.
@@ -748,6 +749,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
         attrs=attrs)
     return out
 
+
 @templatedoc(op_type="sum")
 def elementwise_sum(inputs, name=None):
     """
@@ -930,3 +932,65 @@ def mm(input, mat2, out=None, name=None):
         type='matmul', inputs={'X': input, 'Y': mat2}, outputs={'Out': out})
     return out
+
+
+def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
+    """
+    **addmm**
+
+    This operator is used to perform matrix multiplication for input $x$ and $y$.
+    $input$ is added to the final result.
+    The equation is:
+
+    .. math::
+        Out = alpha * x * y + beta * input
+
+    $Input$, $x$ and $y$ can carry the LoD (Level of Details) information, or not. But the output only shares the LoD information with input $input$.
+
+    Args:
+        input (Variable): The input Tensor/LoDTensor to be added to the final result.
+        x (Variable): The first input Tensor/LoDTensor for matrix multiplication.
+        y (Variable): The second input Tensor/LoDTensor for matrix multiplication.
+        alpha (float): Coefficient of $x*y$.
+        beta (float): Coefficient of $input$.
+        name (str, optional): Name of the output. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default is None.
+
+    Returns:
+        Variable(Tensor/LoDTensor): The output Tensor/LoDTensor of addmm op.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.fluid as fluid
+
+            input = fluid.data(name='input', shape=[2, 2], dtype='float32')
+            x = fluid.data(name='x', shape=[2, 2], dtype='float32')
+            y = fluid.data(name='y', shape=[2, 2], dtype='float32')
+            out = paddle.addmm(input=input, x=x, y=y, alpha=5.0, beta=0.5)
+
+            data_x = np.ones((2, 2)).astype(np.float32)
+            data_y = np.ones((2, 2)).astype(np.float32)
+            data_input = np.ones((2, 2)).astype(np.float32)
+
+            place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            results = exe.run(fluid.default_main_program(),
+                              fetch_list=[out], feed={"input": data_input, 'x': data_x, "y": data_y})
+            print(np.array(results[0]))
+            # [[10.5 10.5]
+            #  [10.5 10.5]]
+    """
+    inputs = {'Input': input, "X": x, "Y": y}
+    attrs = {'Alpha': alpha, 'Beta': beta}
+
+    helper = LayerHelper("addmm", **locals())
+    check_variable_and_dtype(input, 'Input', ['float32', 'float64'], 'addmm')
+    check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
+    check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type="addmm", inputs=inputs, attrs=attrs, outputs={"Out": out})
+    return out
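
Reviewer note: a minimal NumPy sketch of the semantics this patch implements, i.e. the formula used by the docstring and the new unit tests (Out = alpha * x * y + beta * input, with `input` broadcast against the [x.shape[0], y.shape[1]] result). The helper name `addmm_ref` is illustrative only and is not part of the patch:

    import numpy as np

    def addmm_ref(input, x, y, alpha=1.0, beta=1.0):
        # Out = alpha * (x @ y) + beta * input; `input` is broadcast to the
        # [x.shape[0], y.shape[1]] result shape, as exercised by TestAddMMOp3.
        return alpha * np.dot(x, y) + beta * input

    out = addmm_ref(np.ones((1, 100)), np.ones((20, 10)), np.ones((10, 100)),
                    alpha=0.5, beta=2.0)
    assert out.shape == (20, 100)  # the (1, 100) Input broadcasts over rows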