未验证 提交 1c08a213 编写于 作者: L littletomatodonkey 提交者: GitHub

test=develop, add addmm op (#23384)

add addmm op
上级 7b5e23c0
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/addmm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class AddMMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::NotFound(
"Input(Input) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of AddMMOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of AddMMOp should not be null."));
auto input_dims = ctx->GetInputDim("Input");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto ndim_input = input_dims.size();
auto ndim_x = x_dims.size();
auto ndim_y = y_dims.size();
float alpha = ctx->Attrs().Get<float>("Alpha");
float beta = ctx->Attrs().Get<float>("Beta");
VLOG(3) << "addmm operator input.shape=" << input_dims
<< " x.shape=" << x_dims << " y.shape=" << y_dims
<< " beta=" << beta << " alpha=" << alpha
<< " ndim_input=" << ndim_input << " ndim_x=" << ndim_x
<< " ndim_y=" << ndim_y;
PADDLE_ENFORCE_NE(framework::product(input_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Input(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Input").front()));
PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable X(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("X").front()));
PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Y(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Y").front()));
// dim check
PADDLE_ENFORCE_EQ(ndim_input, 2,
platform::errors::InvalidArgument(
"The input tensor input's dimension must be 2. "
"But received input's dimension = [%s].",
ndim_input));
PADDLE_ENFORCE_EQ(ndim_x, 2,
platform::errors::InvalidArgument(
"The input tensor x's dimension must be 2. "
"But received x's dimension = [%s].",
ndim_x));
PADDLE_ENFORCE_EQ(ndim_y, 2,
platform::errors::InvalidArgument(
"The input tensor y's dimension must be 2. "
"But received y's dimension = [%s].",
ndim_y));
std::vector<int64_t> output_dims;
output_dims.push_back(x_dims[0]);
output_dims.push_back(y_dims[1]);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
}
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};
class AddMMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor), tensor to be added to the final result.");
AddInput("X", "(Tensor), The first input tensor for mul.");
AddInput("Y", "(Tensor), The second input tensor for mul.");
AddOutput("Out", "(Tensor), The output tensor of addmm op.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<float>("Alpha", "coefficient of x*y.").SetDefault(1.0f);
AddAttr<float>("Beta", "coefficient of input.").SetDefault(1.0f);
AddComment(R"DOC(
AddMM Operator.
This operator is used to perform matrix multiplication for input $x$ and $y$ with coefficient $alpha$.
$input$ with coefficient $beta$ is added to the final result.
The equation is:
$$Out = alpha * x * y + beta * input$$
$x$ and $y$ must be two-dimensional, and $input$ can be broadcastable.
)DOC");
}
};
class AddMMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::NotFound("Input(Input) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
const auto& input_dims = ctx->GetInputDim("Input");
const auto& x_dims = ctx->GetInputDim("X");
const auto& y_dims = ctx->GetInputDim("Y");
auto input_grad_name = framework::GradVarName("Input");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(input_grad_name)) {
ctx->SetOutputDim(input_grad_name, input_dims);
}
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("addmm_grad");
retv->SetInput("Input", this->Input("Input"));
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
ops::AddMMOpGradMaker<paddle::framework::OpDesc>,
ops::AddMMOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);
REGISTER_OP_CPU_KERNEL(
addmm, ops::AddMMKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
addmm_grad, ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AddMMGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/addmm_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(addmm, ops::AddMMKernel<plat::CUDADeviceContext, float>,
ops::AddMMKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(addmm_grad,
ops::AddMMGradKernel<plat::CUDADeviceContext, float>,
ops::AddMMGradKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Tensor = framework::Tensor;
constexpr int kMULMKLDNNINT8 = 1;
template <typename DeviceContext, typename T>
class AddMMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
auto input_dims = input->dims();
auto x_dims = x->dims();
auto y_dims = y->dims();
// broadcast mode check
if (x_dims[0] != input_dims[0]) {
PADDLE_ENFORCE_EQ(input_dims[0], 1,
platform::errors::InvalidArgument(
"When x_dims[0] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[0]));
PADDLE_ENFORCE_EQ(
y_dims[1] == input_dims[1] || input_dims[1] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
if (y_dims[1] != input_dims[1]) {
PADDLE_ENFORCE_EQ(input_dims[1], 1,
platform::errors::InvalidArgument(
"When y_dims[1] is not equal with input_dims[0], "
"input_dims[0] must be 1 but got %s",
input_dims[1]));
PADDLE_ENFORCE_EQ(
x_dims[0] == input_dims[0] || input_dims[0] == 1, true,
platform::errors::InvalidArgument(
"The input tensor shape mismatch, input shape=[%s], "
"x shape=[%s], y shape=[%s]",
input_dims, x_dims, y_dims));
}
// broadcast mode check
PADDLE_ENFORCE_EQ(
x_dims[1], y_dims[0],
platform::errors::InvalidArgument(
"The input tensor X's width must be equal with matrix Y' height. "
"But received X's shape = [%s], Y's shape = [%s].",
x_dims[1], y_dims[0]));
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>({x_dims[0], y_dims[1]}, context.GetPlace());
float alpha = context.template Attr<float>("Alpha");
float beta = context.template Attr<float>("Beta");
auto blas = math::GetBlas<DeviceContext, T>(context);
// calc broadcast dim
Array2 bcast_dims;
bcast_dims[0] = x_dims[0] / input_dims[0];
bcast_dims[1] = y_dims[1] / input_dims[1];
VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]";
// broadcast using eigen
auto eigen_input = EigenTensor<T, 2>::From(*input);
auto eigen_out = EigenTensor<T, 2>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
eigen_out.device(place) = eigen_input.broadcast(bcast_dims);
blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
out->data<T>(), y_dims[1]);
}
};
template <typename DeviceContext, typename T>
class AddMMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto in_dims = ctx.Input<framework::LoDTensor>("Input")->dims();
auto* dinput =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
float alpha = ctx.Attr<float>("Alpha");
float beta = ctx.Attr<float>("Beta");
int total_elems = 0;
VLOG(3) << "alpha: " << alpha << " beta: " << beta;
if (dinput != nullptr) {
dinput->set_lod(dout->lod());
}
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (dinput) {
dinput->mutable_data<T>(ctx.GetPlace());
total_elems = in_dims[0] * in_dims[1];
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto eigen_dout = EigenTensor<T, 2>::From(*dout);
auto eigen_dinput = EigenTensor<T, 2>::From(*dinput);
bool row_compress = in_dims[0] != dout->dims()[0];
bool col_compress = in_dims[1] != dout->dims()[1];
auto eigen_dinput_shape = Array2(dinput->dims()[0], dinput->dims()[1]);
if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
blas.VCOPY(total_elems, dout->data<T>(), dinput->data<T>());
}
blas.SCAL(total_elems, beta, dinput->data<T>());
}
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[0] * x->dims()[1];
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
blas.MatMul(*dout, false, *y, true, dx);
blas.SCAL(total_elems, alpha, dx->data<T>());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
total_elems = x->dims()[1] * y->dims()[1];
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
blas.MatMul(*x, true, *dout, false, dy);
blas.SCAL(total_elems, alpha, dy->data<T>());
}
}
};
} // namespace operators
} // namespace paddle
...@@ -39,6 +39,20 @@ struct CUBlas<float> { ...@@ -39,6 +39,20 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
} }
template <typename... ARGS>
static void SCAL(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSscal(args...),
platform::errors::External("dynload cublasSscal lib failed"));
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasScopy(args...),
platform::errors::External("dynload cublasScopy lib failed"));
}
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
...@@ -92,6 +106,20 @@ struct CUBlas<double> { ...@@ -92,6 +106,20 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
} }
template <typename... ARGS>
static void SCAL(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDscal(args...),
platform::errors::External("dynload cublasDscal lib failed"));
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDcopy(args...),
platform::errors::External("dynload cublasDcopy lib failed"));
}
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
...@@ -318,6 +346,20 @@ void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x, ...@@ -318,6 +346,20 @@ void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
}); });
} }
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::SCAL(int n, const T alpha, T *x) const {
context_.CublasCall(
[&](cublasHandle_t handle) { CUBlas<T>::SCAL(handle, n, &alpha, x, 1); });
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::VCOPY(int n, const T *x, T *y) const {
context_.CublasCall(
[&](cublasHandle_t handle) { CUBlas<T>::VCOPY(handle, n, x, 1, y, 1); });
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N, void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
......
...@@ -64,6 +64,10 @@ extern void *cublas_dso_handle; ...@@ -64,6 +64,10 @@ extern void *cublas_dso_handle;
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSaxpy_v2); \ __macro(cublasSaxpy_v2); \
__macro(cublasDaxpy_v2); \ __macro(cublasDaxpy_v2); \
__macro(cublasSscal_v2); \
__macro(cublasDscal_v2); \
__macro(cublasScopy_v2); \
__macro(cublasDcopy_v2); \
__macro(cublasSgemv_v2); \ __macro(cublasSgemv_v2); \
__macro(cublasDgemv_v2); \ __macro(cublasDgemv_v2); \
__macro(cublasSgemm_v2); \ __macro(cublasSgemm_v2); \
......
...@@ -142,7 +142,7 @@ from .tensor.math import add #DEFINE_ALIAS ...@@ -142,7 +142,7 @@ from .tensor.math import add #DEFINE_ALIAS
# from .tensor.math import log1p #DEFINE_ALIAS # from .tensor.math import log1p #DEFINE_ALIAS
# from .tensor.math import erf #DEFINE_ALIAS # from .tensor.math import erf #DEFINE_ALIAS
# from .tensor.math import addcmul #DEFINE_ALIAS # from .tensor.math import addcmul #DEFINE_ALIAS
# from .tensor.math import addmm #DEFINE_ALIAS from .tensor.math import addmm #DEFINE_ALIAS
# from .tensor.attribute import rank #DEFINE_ALIAS # from .tensor.attribute import rank #DEFINE_ALIAS
# from .tensor.attribute import shape #DEFINE_ALIAS # from .tensor.attribute import shape #DEFINE_ALIAS
# from .tensor.io import save #DEFINE_ALIAS # from .tensor.io import save #DEFINE_ALIAS
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestAddMMOp(OpTest):
# test basic
def setUp(self):
self.op_type = "addmm"
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((100, 1)).astype(self.dtype),
'X': np.random.random((100, 10)).astype(self.dtype),
'Y': np.random.random((10, 20)).astype(self.dtype),
}
self.outputs = {
'Out':
self.inputs['Input'] + np.dot(self.inputs['X'], self.inputs['Y'])
}
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input', 'X', 'Y'], 'Out')
def test_check_grad_x(self):
self.check_grad(['X'], 'Out', no_grad_set=None)
def test_check_grad_y(self):
self.check_grad(['Y'], 'Out', no_grad_set=None)
def test_check_grad_input(self):
self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMOpError(unittest.TestCase):
# test error
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of addmm_op must be Variable.
input = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, paddle.addmm, input, x1, x2)
# The input dtype of mul_op must be float32 or float64.
input = fluid.layers.data(name='input', shape=[4], dtype="int32")
x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
self.assertRaises(TypeError, paddle.addmm, input, x3, x4)
class TestAddMMOp2(TestAddMMOp):
# test alpha and beta
def setUp(self):
self.op_type = "addmm"
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((20, 30)).astype(self.dtype),
'X': np.random.random((20, 6)).astype(self.dtype),
'Y': np.random.random((6, 30)).astype(self.dtype),
}
self.attrs = {
'Alpha': 0.1,
'Beta': 1.0,
}
self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}
class TestAddMMOp3(OpTest):
# test broadcast
def setUp(self):
self.op_type = "addmm"
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((1, 100)).astype(self.dtype),
'X': np.random.random((20, 10)).astype(self.dtype),
'Y': np.random.random((10, 100)).astype(self.dtype),
}
self.attrs = {
'Alpha': 0.5,
'Beta': 2.0,
}
self.outputs = {'Out': self.attrs['Beta'] * self.inputs['Input'] + \
self.attrs['Alpha'] * np.dot(self.inputs['X'], self.inputs['Y'])}
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input', 'X', 'Y'], 'Out')
def test_check_grad_x(self):
self.check_grad(['X'], 'Out', no_grad_set=None)
def test_check_grad_y(self):
self.check_grad(['Y'], 'Out', no_grad_set=None)
def test_check_grad_input(self):
self.check_grad(['Input'], 'Out', no_grad_set=None)
if __name__ == "__main__":
unittest.main()
...@@ -3464,6 +3464,28 @@ class TestBook(LayerTest): ...@@ -3464,6 +3464,28 @@ class TestBook(LayerTest):
x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25) x=input, label=label, fg_num=fg_num, gamma=2., alpha=0.25)
return (out) return (out)
def test_addmm(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input_data',
shape=[3, 3],
append_batch_size=False,
dtype='float32')
x = layers.data(
name='x',
shape=[3, 2],
append_batch_size=False,
dtype='float32')
y = layers.data(
name='y',
shape=[2, 3],
append_batch_size=False,
dtype='float32')
out = paddle.addmm(input=input, x=x, y=y)
return (out)
def test_retinanet_detection_output(self): def test_retinanet_detection_output(self):
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
fluid.default_startup_program()): fluid.default_startup_program()):
......
...@@ -117,7 +117,7 @@ from .math import add #DEFINE_ALIAS ...@@ -117,7 +117,7 @@ from .math import add #DEFINE_ALIAS
# from .math import log1p #DEFINE_ALIAS # from .math import log1p #DEFINE_ALIAS
# from .math import erf #DEFINE_ALIAS # from .math import erf #DEFINE_ALIAS
# from .math import addcmul #DEFINE_ALIAS # from .math import addcmul #DEFINE_ALIAS
# from .math import addmm #DEFINE_ALIAS from .math import addmm #DEFINE_ALIAS
# from .attribute import rank #DEFINE_ALIAS # from .attribute import rank #DEFINE_ALIAS
# from .attribute import shape #DEFINE_ALIAS # from .attribute import shape #DEFINE_ALIAS
# from .io import save #DEFINE_ALIAS # from .io import save #DEFINE_ALIAS
......
...@@ -75,7 +75,7 @@ __all__ = [ ...@@ -75,7 +75,7 @@ __all__ = [
# 'log1p', # 'log1p',
# 'erf', # 'erf',
# 'addcmul', # 'addcmul',
# 'addmm'] 'addmm'
] ]
# yapf: enable. # yapf: enable.
...@@ -648,6 +648,7 @@ for func in [ ...@@ -648,6 +648,7 @@ for func in [
skip_attrs_set={"x_data_format", "y_data_format", "axis" skip_attrs_set={"x_data_format", "y_data_format", "axis"
}) + """\n""" + str(func.__doc__) }) + """\n""" + str(func.__doc__)
def sum(input, dim=None, dtype=None, keep_dim=False, name=None): def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
""" """
Computes the sum of tensor elements over the given dimension. Computes the sum of tensor elements over the given dimension.
...@@ -748,6 +749,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None): ...@@ -748,6 +749,7 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
attrs=attrs) attrs=attrs)
return out return out
@templatedoc(op_type="sum") @templatedoc(op_type="sum")
def elementwise_sum(inputs, name=None): def elementwise_sum(inputs, name=None):
""" """
...@@ -930,3 +932,65 @@ def mm(input, mat2, out=None, name=None): ...@@ -930,3 +932,65 @@ def mm(input, mat2, out=None, name=None):
type='matmul', inputs={'X': input, type='matmul', inputs={'X': input,
'Y': mat2}, outputs={'Out': out}) 'Y': mat2}, outputs={'Out': out})
return out return out
def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
"""
**addmm**
This operator is used to perform matrix multiplication for input $x$ and $y$.
$input$ is added to the final result.
The equation is:
.. math::
Out = alpha * x * y + beta * input
$Input$, $x$ and $y$ can carry the LoD (Level of Details) information, or not. But the output only shares the LoD information with input $input$.
Args:
input (Variable): The input Tensor/LoDTensor to be added to the final result.
x (Variable): The first input Tensor/LoDTensor for matrix multiplication.
y (Variable): The second input Tensor/LoDTensor for matrix multiplication.
alpha (float): Coefficient of $x*y$.
beta (float): Coefficient of $input$.
name (str, optional): Name of the output. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default is None.
Returns:
Variable(Tensor/LoDTensor): The output Tensor/LoDTensor of addmm op.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
input = fluid.data(name='input', shape=[2, 2], dtype='float32')
x = fluid.data(name='x', shape=[2, 2], dtype='float32')
y = fluid.data(name='y', shape=[2, 2], dtype='float32')
out = paddle.addmm( input=input, x=x, y=y, alpha=5.0, beta=0.5 )
data_x = np.ones((2, 2)).astype(np.float32)
data_y = np.ones((2, 2)).astype(np.float32)
data_input = np.ones((2, 2)).astype(np.float32)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
results = exe.run(fluid.default_main_program(),
fetch_list=[out], feed={"input": data_input, 'x': data_x, "y": data_y})
print( np.array(results[0]) )
# [[10.5 10.5]
# [10.5 10.5]]
"""
inputs = {'Input': input, "X": x, "Y": y}
attrs = {'Alpha': alpha, 'Beta': beta}
helper = LayerHelper("addmm", **locals())
check_variable_and_dtype(x, 'Input', ['float32', 'float64'], 'addmm')
check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="addmm", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册