Unverified · Commit c6090660 authored by ShenLiang, committed by GitHub

Add Matmul op (#26411)

* add matmul_v2
Parent 65ac1ef6
@@ -26,6 +26,86 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
#endif
}
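For intuition, the function above implements the dot-product gradient: with out = sum_k x_k * y_k, the chain rule gives dx = dout * y and dy = dout * x, applied row-wise in the batched (2-D) branch. A minimal NumPy sketch of the same computation (illustrative only, not part of this patch):

import numpy as np

def dot_grad(x, y, dout):
    # x, y: (batch, k); dout: (batch,) holds one upstream scalar per row.
    dx = dout[:, None] * y   # dx[i, j] = dout[i] * y[i, j]
    dy = dout[:, None] * x   # dy[i, j] = dout[i] * x[i, j]
    return dx, dy

x, y = np.random.rand(3, 4), np.random.rand(3, 4)
dx, dy = dot_grad(x, y, np.ones(3))
assert np.allclose(dx, y) and np.allclose(dy, x)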
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
public:
@@ -84,83 +164,9 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
#endif
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
}
};
......
@@ -198,6 +198,11 @@ class Blas {
int K, T alpha, const T* A, const T* B, T beta, T* C,
int batchCount, int64_t strideA, int64_t strideB) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T** A, const T** B, T beta, T** C,
int batchCount) const;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
......
@@ -458,6 +458,17 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#endif // CUDA_VERSION >= 9010
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
......
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
@@ -655,6 +656,26 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
#ifdef PADDLE_WITH_MKLML
const int lda = std::max((transA == CblasNoTrans) ? K : M, 1);
const int ldb = std::max((transB == CblasNoTrans) ? N : K, 1);
const int ldc = std::max(N, 1);
CBlas<T>::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A,
&lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */,
&batchCount);
#else
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
#endif
}
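For reference, the pointer-array overload added here performs one independent GEMM per batch entry, C[k] = alpha * op(A[k]) * op(B[k]) + beta * C[k]; with MKL this maps to a single cblas_?gemm_batch call with one parameter group, and otherwise it falls back to a plain loop over GEMM. A NumPy sketch of the contract (hypothetical helper name, no transposes assumed):

import numpy as np

def batched_gemm(As, Bs, alpha=1.0, beta=0.0, Cs=None):
    # One independent GEMM per batch entry: C[k] = alpha*A[k]@B[k] + beta*C[k].
    if Cs is None:
        Cs = [np.zeros((a.shape[0], b.shape[1])) for a, b in zip(As, Bs)]
    return [alpha * (a @ b) + beta * c for a, b, c in zip(As, Bs, Cs)]

As = [np.random.rand(2, 3) for _ in range(4)]
Bs = [np.random.rand(3, 5) for _ in range(4)]
assert all(c.shape == (2, 5) for c in batched_gemm(As, Bs))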
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <>
template <typename T>
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/matmul_v2_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class MatMulV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
bool trans_x = ctx->Attrs().Get<bool>("trans_x");
bool trans_y = ctx->Attrs().Get<bool>("trans_y");
std::vector<int64_t> dims_x =
paddle::framework::vectorize(ctx->GetInputDim("X"));
std::vector<int64_t> dims_y =
paddle::framework::vectorize(ctx->GetInputDim("Y"));
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
bool x_broadcasted = false, y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}
if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}
std::vector<int64_t> new_dims;
if (ndims_x >= ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto out_dims = framework::make_ddim(new_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
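The InferShape logic above can be mirrored in a few lines of Python; this sketch (a hypothetical helper, assuming all dims are known and positive) promotes 1-D inputs, selects M and N according to the transpose flags, takes the batch dims from the higher-rank input, and finally drops the dims that came from 1-D promotion:

def infer_matmul_v2_shape(dims_x, dims_y, trans_x=False, trans_y=False):
    dims_x, dims_y = list(dims_x), list(dims_y)
    x_bc = y_bc = False
    if len(dims_x) == 1:
        dims_x, x_bc = [1] + dims_x, True      # [D] -> [1, D]
    if len(dims_y) == 1:
        dims_y, y_bc = dims_y + [1], True      # [D] -> [D, 1]
    M = dims_x[-1] if trans_x else dims_x[-2]
    N = dims_y[-2] if trans_y else dims_y[-1]
    out = list(dims_x[:-2] if len(dims_x) >= len(dims_y) else dims_y[:-2])
    if not x_bc:
        out.append(M)
    if not y_bc:
        out.append(N)
    return out or [1]                          # both 1-D -> scalar-like [1]

assert infer_matmul_v2_shape([100], [100]) == [1]
assert infer_matmul_v2_shape([2, 3, 4], [4, 5]) == [2, 3, 5]
assert infer_matmul_v2_shape([2, 3, 4], [4]) == [2, 3]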
class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "tensor of shape (d0, d1 ... M, K)");
AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");
AddAttr<bool>("trans_x",
"Set true to transpose the last two dimensions of X before "
"doing multiplication")
.SetDefault(false);
AddAttr<bool>("trans_y",
"Set true to transpose the last two dimensions of Y before "
"doing multiplication")
.SetDefault(false);
AddComment(
R"DOC(Matrix multiplication Out = X * Y. X has shape (d0, d1 ... M, K),
Y has shape (d0, d1 ... K, N), and Out has shape (d0, d1 ... M, N).
In addition, it follows broadcast rules similar to those of
numpy.matmul.
)DOC");
}
};
class MatMulV2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "matmul_v2");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(y_grad_name)) {
context->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class MatMulV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("matmul_v2_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad);
REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_grad,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/matmul_v2_op.h"
namespace ops = paddle::operators;
namespace plf = paddle::platform;
REGISTER_OP_CUDA_KERNEL(matmul_v2,
ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <functional>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#ifdef __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
template <typename DeviceContext, typename T>
void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims,
const paddle::framework::ExecutionContext& ctx) {
if (reduce_dims.empty()) {
// FIXME: maybe avoid this copy operation
framework::TensorCopySync(*input, ctx.GetPlace(), output);
return;
}
#ifdef __NVCC__
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
#else
ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
input, output, reduce_dims, true, false, ctx)
.template apply<T>();
#endif
}
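ReduceSumForMatmulGrad collapses the intermediate gradient along the batch dimensions that were broadcast in the forward pass (via a cub reduction on GPU and the reduce_sum functor with keep_dim=true on CPU); the caller then Resizes the result back to the input's shape. The NumPy equivalent, as a sketch:

import numpy as np

def reduce_sum_for_matmul_grad(grad, reduce_dims):
    if not reduce_dims:          # nothing was broadcast: just copy through
        return grad.copy()
    return grad.sum(axis=tuple(reduce_dims), keepdims=True)

g = np.ones((3, 2, 4, 5))        # e.g. dx_help with a broadcast batch dim of 3
dx = reduce_sum_for_matmul_grad(g, [0])
assert dx.shape == (1, 2, 4, 5) and float(dx[0, 0, 0, 0]) == 3.0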
static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims,
const int y_ndim, const std::int64_t* y_dims,
std::int64_t* x_bd_dims,
std::int64_t* y_bd_dims,
std::int64_t* out_bd_dims) {
const int ndim = std::max(x_ndim, y_ndim);
std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1);
std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1);
std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim);
std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim);
for (int i = 0; i < ndim; ++i) {
PADDLE_ENFORCE_EQ(
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1,
true, platform::errors::InvalidArgument(
          "Input(X) and Input(Y) have incompatible broadcast dimensions."));
if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) {
out_bd_dims[i] = 0;
} else {
out_bd_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]);
}
}
}
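GetBroadcastFromDims left-pads the shorter batch shape with 1s and takes the elementwise maximum (zero-sized dims propagate as 0), i.e. NumPy-style broadcasting restricted to the batch dimensions. A short Python sketch under the same rules:

def broadcast_batch_dims(x_dims, y_dims):
    ndim = max(len(x_dims), len(y_dims))
    xb = [1] * (ndim - len(x_dims)) + list(x_dims)   # left-pad with 1s
    yb = [1] * (ndim - len(y_dims)) + list(y_dims)
    out = []
    for a, b in zip(xb, yb):
        if not (a == b or a <= 1 or b <= 1):
            raise ValueError("Input(X) and Input(Y) have incompatible dims.")
        out.append(0 if a == 0 or b == 0 else max(a, b))
    return xb, yb, out

assert broadcast_batch_dims([3, 1], [2])[2] == [3, 2]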
static int64_t GetIndexMessage(const int n, const int64_t* dims,
const int64_t* index) {
int64_t sum = 0;
for (int i = 0; i < n; ++i) {
if (dims[i] > 1) {
sum = sum * dims[i] + index[i];
}
}
return sum;
}
static void IndexIncreaseFromDims(const int ndim, const int64_t* dims,
int64_t* index) {
for (int i = ndim - 1; i >= 0; --i) {
++index[i];
if (index[i] >= dims[i]) {
index[i] -= dims[i];
} else {
break;
}
}
}
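GetIndexMessage flattens a multi-dimensional batch index into a linear offset while skipping size-1 (broadcast) dims, so a broadcast operand keeps reusing the same matrix slice; IndexIncreaseFromDims advances the index like an odometer, last dimension fastest. In Python, as a sketch:

def get_index_offset(dims, index):
    offset = 0
    for d, i in zip(dims, index):
        if d > 1:                 # broadcast (size-1) dims contribute nothing
            offset = offset * d + i
    return offset

def index_increase(dims, index):
    # Odometer-style increment of a multi-dimensional index.
    for i in reversed(range(len(dims))):
        index[i] += 1
        if index[i] >= dims[i]:
            index[i] -= dims[i]   # carry into the next (slower) dimension
        else:
            break

idx = [0, 1]
index_increase([2, 2], idx)
assert idx == [1, 0]
assert get_index_offset([2, 1, 2], [1, 0, 1]) == 3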
template <typename DeviceContext, typename T>
void MatMulFunction(const Tensor* X, const Tensor* Y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims, Tensor* Out,
bool trans_x, bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
// get data ptr
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
if (x_ndim == 1 && y_ndim == 1) {
PADDLE_ENFORCE_EQ(X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers is not equal to Y's numbers,"
"when X/Y's dims =1"));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>(ctx.GetPlace());
auto out_eigen = framework::EigenScalar<T>::From(*Out);
auto x_eigen = framework::EigenVector<T>::Flatten(*X);
auto y_eigen = framework::EigenVector<T>::Flatten(*Y);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_ndim == 1) {
const int N = X->numel();
if (trans_y) {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 1], N,
platform::errors::InvalidArgument("Input(Y) has error dim."));
} else {
PADDLE_ENFORCE_EQ(
y_dims[y_ndim - 2], N,
platform::errors::InvalidArgument("Input(Y) has error dim."));
}
std::vector<std::int64_t> out_dims(y_ndim - 1);
if (trans_y) {
std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
} else {
std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
out_dims.back() = y_dims.back();
}
Out->Resize(framework::make_ddim(out_dims));
Out->mutable_data<T>(ctx.GetPlace());
if (trans_y) {
const int M = Y->numel() / N;
VLOG(3) << "MatMul's case 2";
blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data<T>());
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 3";
blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data<T>());
} else {
VLOG(3) << "MatMul's case 4";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data,
x_data, 0, Out->data<T>(), batch_size, M * N, 0);
}
}
return;
}
if (y_ndim == 1) {
const int N = Y->numel();
if (trans_x) {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 2], N,
platform::errors::InvalidArgument("Input(X) has error dim."));
} else {
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 1], N,
platform::errors::InvalidArgument("Input(X) has error dim."));
}
std::vector<std::int64_t> out_dims(x_ndim - 1);
if (trans_x) {
std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
out_dims.back() = x_dims.back();
} else {
std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
}
Out->Resize(framework::make_ddim(out_dims));
Out->mutable_data<T>(ctx.GetPlace());
if (trans_x) {
const int M = x_dims[x_ndim - 1];
const int batch_size = X->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 5";
blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
} else {
VLOG(3) << "MatMul's case 6";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data,
y_data, 0, Out->data<T>(), batch_size, M * N, 0);
}
} else {
const int M = X->numel() / N;
VLOG(3) << "MatMul's case 7";
blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
}
return;
}
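Cases 2 through 7 above handle the 1-D operands: the vector is treated as a row vector for X or a column vector for Y, the multiply lowers to a GEMV (or a strided BatchedGEMM when the matrix side is batched), and the inserted unit dimension is dropped from the output, matching NumPy:

import numpy as np

x = np.random.rand(100)          # 1-D X, treated as [1, 100]
y = np.random.rand(4, 100, 2)    # batched Y
out = np.matmul(x, y)            # one GEMV per batch entry of y
assert out.shape == (4, 2)       # the prepended 1 is squeezed away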
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
                      platform::errors::InvalidArgument(
                          "Input(Y) has an incorrect dimension."));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
                      platform::errors::InvalidArgument(
                          "Input(Y) has an incorrect dimension."));
}
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
const int ndim = std::max(x_ndim, y_ndim);
std::vector<std::int64_t> x_broadcast_dims(ndim);
std::vector<std::int64_t> y_broadcast_dims(ndim);
std::vector<std::int64_t> out_broadcast_dims(ndim);
GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(),
x_broadcast_dims.data(), y_broadcast_dims.data(),
out_broadcast_dims.data());
out_broadcast_dims[ndim - 2] = M;
out_broadcast_dims[ndim - 1] = N;
Out->Resize(framework::make_ddim(out_broadcast_dims));
Out->mutable_data<T>(ctx.GetPlace());
const int batch_dim = ndim - 2;
// whether the batch dimensions require broadcasting
const bool is_broadcast_dims = !std::equal(
x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim,
y_broadcast_dims.cbegin());
const std::int64_t x_batch_size = std::accumulate(
x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
const std::int64_t y_batch_size = std::accumulate(
y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
const std::int64_t out_batch_size = std::accumulate(
out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
if (out_batch_size == 0) return;
if (x_batch_size == 1 && y_batch_size == 1) {
VLOG(3) << "MatMul's case 8";
blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
y_data, 0.0f, Out->data<T>());
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul's case 9";
blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f,
Out->data<T>());
} else {
VLOG(3) << "MatMul's case 10";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
x_data, y_data, 0, Out->data<T>(), out_batch_size, 0,
K * N);
}
} else if (y_batch_size == 1) {
if (!trans_x) {
VLOG(3) << "MatMul's case 11";
blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f,
Out->data<T>());
} else {
VLOG(3) << "MatMul's case 12";
blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
1.0f, x_data, y_data, 0, Out->data<T>(), out_batch_size,
M * K, 0);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul's case 13";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
y_data, 0, Out->data<T>(), out_batch_size, M * K, K * N);
} else {
// in this case strided GEMM cannot be used, so fall back to the
// pointer-array BatchedGEMM
std::vector<const T*> x_ptr(out_batch_size);
std::vector<const T*> y_ptr(out_batch_size);
std::vector<T*> out_ptr(out_batch_size);
std::vector<std::int64_t> index(batch_dim, 0);
for (std::int64_t i = 0; i < out_batch_size; ++i) {
// use the multi-dimensional index to compute each operand's batch offset
const std::int64_t x_index =
GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
const std::int64_t y_index =
GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());
x_ptr[i] = x_data + x_index * M * K;
y_ptr[i] = y_data + y_index * K * N;
out_ptr[i] = Out->data<T>() + i * M * N;
IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
}
VLOG(3) << "MatMul's case 14";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(),
out_batch_size);
}
}
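End to end, MatMulFunction is expected to agree with np.matmul under full broadcasting; case 14 (batch shapes that broadcast against each other) is the one that requires the pointer-array BatchedGEMM above. A quick NumPy cross-check of that shape pattern:

import numpy as np

x = np.random.rand(3, 1, 4, 6)   # batch dims (3, 1)
y = np.random.rand(1, 2, 6, 5)   # batch dims (1, 2); broadcast to (3, 2)
out = np.matmul(x, y)
assert out.shape == (3, 2, 4, 5)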
template <typename DeviceContext, typename T>
void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
const std::vector<std::int64_t> x_dims = vectorize(X->dims());
const std::vector<std::int64_t> y_dims = vectorize(Y->dims());
MatMulFunction<DeviceContext, T>(X, Y, x_dims, y_dims, Out, trans_x, trans_y,
ctx);
}
template <typename DeviceContext, typename T>
class MatMulV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* Out = ctx.Output<Tensor>("Out");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);
}
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
// get dims
std::vector<std::int64_t> x_dims = vectorize(X->dims());
std::vector<std::int64_t> y_dims = vectorize(Y->dims());
std::vector<std::int64_t> dout_dims = vectorize(dOut->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// special cases where x's and/or y's rank is 1
if (x_ndim == 1 && y_ndim == 1) {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dOut->numel() == 1) {
DotGradFunction<DeviceContext, T>(X, Y, dOut, dx, dy, ctx);
return;
}
}
// This is tricky: for the broadcast cases the gradient is currently
// recovered via a reduce_sum, so promote 1-D operands to 2-D first.
if (x_ndim == 1) {
x_dims.insert(x_dims.begin() + 0, 1);
x_ndim += 1;
if (trans_x)
dout_dims.push_back(1);
else
dout_dims.insert(dout_dims.begin() + ndim - 1, 1);
ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
y_ndim += 1;
if (trans_y)
dout_dims.insert(dout_dims.begin() + ndim - 1, 1);
else
dout_dims.push_back(1);
ndim += 1;
}
// the normal case
Tensor dx_help, dy_help;
if (trans_x) {
if (trans_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(Y, dOut, y_dims, dout_dims, &dx_help,
true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(dOut, X, dout_dims, x_dims, &dy_help,
true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(Y, dOut, y_dims, dout_dims, &dx_help,
false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(X, dOut, x_dims, dout_dims, &dy_help,
false, false, ctx);
}
} else {
if (trans_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(dOut, Y, dout_dims, y_dims, &dx_help,
false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(dOut, X, dout_dims, x_dims, &dy_help,
true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(dOut, Y, dout_dims, y_dims, &dx_help,
false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(X, dOut, x_dims, dout_dims, &dy_help,
true, false, ctx);
}
}
// get the dims of the intermediate (helper) gradients
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// reduce-sum over the broadcast dims to obtain the final gradients
if (dx) {
dx->Resize(dx_help.dims());
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
ctx);
dx->Resize(X->dims());
}
if (dy) {
dy->Resize(dy_help.dims());
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
ctx);
dy->Resize(Y->dims());
}
}
};
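The four flag combinations in the gradient kernel follow from the matrix calculus of Out = op(X) op(Y): for the plain case, dX = dOut * Y' and dY = X' * dOut, and transposed inputs permute or transpose the factors accordingly. A NumPy sanity check of the untransposed case:

import numpy as np

x = np.random.rand(4, 3)
y = np.random.rand(3, 5)
dout = np.random.rand(4, 5)      # upstream gradient of Out = x @ y
dx = dout @ y.T                  # XY: dX = G Y'
dy = x.T @ dout                  # XY: dY = X' G
assert dx.shape == x.shape and dy.shape == y.shape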
} // namespace operators
} // namespace paddle
@@ -26,7 +26,7 @@ import six
import paddle
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator
from .. import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
@@ -5033,6 +5033,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.matmul")
def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
"""
Applies matrix multiplication to two tensors.
@@ -5104,7 +5105,65 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
y = fluid.layers.data(name='y', shape=[3, 2], dtype='float32')
out = fluid.layers.matmul(x, y, True, True)
"""
return paddle.matmul(x, y, transpose_x, transpose_y, alpha, name)
attrs = {
'transpose_X': transpose_x,
'transpose_Y': transpose_y,
'alpha': float(alpha),
}
if in_dygraph_mode():
out = _varbase_creator(dtype=x.dtype)
core.ops.matmul(x, y, out, 'transpose_X', transpose_x, 'transpose_Y',
transpose_y, 'alpha', float(alpha))
return out
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'matmul')
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]
# check the inner 2 dimensions
if transpose_x:
x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2]
if transpose_y:
y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
if x_shape[-1] != y_shape[-2]:
assert (x_shape[-1] == -1) or (y_shape[-2] == -1), \
"After performing an optional transpose, Input X's width should be " \
"equal to Y's width for multiplication " \
"prerequisites. But received X's shape: %s, Y's shape: %s\n" % \
(x_shape, y_shape)
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape))
__check_input(x, y)
helper = LayerHelper('matmul', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='matmul',
inputs={'X': x,
'Y': y},
outputs={'Out': out},
attrs=attrs)
return out
def topk(input, k, name=None):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
if X.ndim == 1:
X = X.reshape((X.size, ))
elif X.ndim == 2:
X = X.T
else:
dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim))
if transpose_Y:
if Y.ndim == 1:
Y = Y.reshape((Y.size, ))
else:
dim = [i for i in range(len(Y.shape))]
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
return Out
class TestMatMulV2Op(OpTest):
"""
case 1
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
def setUp(self):
self.config()
self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
result = reference_matmul(x, y, self.trans_x, self.trans_y)
self.inputs = {
'X': x,
'Y': y,
}
self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out')
class TestMatMulOp2(TestMatMulV2Op):
"""
case 2
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 3, 2, 100)
self.trans_x = False
self.trans_y = True
self.dtype = "float64"
class TestMatMulOp3(TestMatMulV2Op):
"""
case 3
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp4(TestMatMulV2Op):
"""
case 4
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 2, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp5(TestMatMulV2Op):
"""
case 5
"""
def config(self):
self.x_shape = (1, 1, 100, 2)
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp6(TestMatMulV2Op):
"""
case 6
"""
def config(self):
self.x_shape = (1, 2, 100, 1)
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp7(TestMatMulV2Op):
"""
case 7
"""
def config(self):
self.x_shape = (1, 2, 1, 100)
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp8(TestMatMulV2Op):
"""
case 8
"""
def config(self):
self.x_shape = (1, 1, 2, 100)
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp9(TestMatMulV2Op):
"""
case 9
"""
def config(self):
self.x_shape = (1, 1, 1, 100)
self.y_shape = (2, 1, 2, 100)
self.trans_x = False
self.trans_y = True
self.dtype = "float64"
class TestMatMulOp10(TestMatMulV2Op):
"""
case 10
"""
def config(self):
self.x_shape = (1, 1, 2, 100)
self.y_shape = (1, 2, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp11(TestMatMulV2Op):
"""
case 11
"""
def config(self):
self.x_shape = (2, 1, 2, 100)
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp12(TestMatMulV2Op):
"""
case 12
"""
def config(self):
self.x_shape = (2, 1, 100, 2)
self.y_shape = (1, 1, 100, 2)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp13(TestMatMulV2Op):
"""
case 13
"""
def config(self):
self.x_shape = (2, 2, 100, 2)
self.y_shape = (2, 2, 100, 2)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp14(TestMatMulV2Op):
"""
case 14_1
"""
def config(self):
self.x_shape = (3, 1, 1, 100, 2)
self.y_shape = (1, 2, 2, 100, 2)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp15(TestMatMulV2Op):
"""
case 14_2
"""
def config(self):
self.x_shape = (3, 1, 1, 2, 100)
self.y_shape = (1, 2, 2, 100, 1)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp16(TestMatMulV2Op):
"""
case 16: check the gradient for a special case
"""
def config(self):
self.x_shape = (100)
self.y_shape = (1, 2, 2, 100, 1)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulOp17(TestMatMulV2Op):
"""
case 17: check the gradient for a special case
"""
def config(self):
self.x_shape = (2, 1, 100)
self.y_shape = (100)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMulV2API(unittest.TestCase):
def setUp(self):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")
result = paddle.matmul(input_x, input_y)
x_np = np.random.random([4, 3]).astype("float32")
y_np = np.random.random([3, 4]).astype("float32")
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input_x": x_np,
"input_y": y_np},
fetch_list=[result])
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
def test_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
input_x = np.random.random([4, 3]).astype("float64")
input_y = np.random.random([3, 4]).astype("float64")
x = paddle.to_tensor(input_x)
y = paddle.to_tensor(input_y)
result = paddle.matmul(x, y)
if __name__ == "__main__":
unittest.main()
@@ -35,135 +35,134 @@ __all__ = [
]
def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
"""
:alias_main: paddle.matmul
:alias: paddle.matmul,paddle.tensor.matmul,paddle.tensor.linalg.matmul
Applies matrix multiplication to two tensors. `matmul` follows
the complete broadcast rules,
and its behavior is consistent with `np.matmul`.
Applies matrix multiplication to two tensors.
Currently, the input tensors' rank can be arbitrary, but when the rank of
either input is larger than 3, the two inputs' ranks must be equal.
Currently, the input tensors may have any number of dimensions, and `matmul`
can be used to achieve `dot`, `matmul` and `batchmatmul` semantics.
The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
- If a transpose flag is specified, the last two dimensions of the tensor
are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for
:math:`x` it is treated as :math:`[1, D]` in nontransposed form and as
:math:`[D, 1]` in transposed form, whereas for :math:`y` it is the
opposite: It is treated as :math:`[D, 1]` in nontransposed form and as
:math:`[1, D]` in transposed form.
- After transpose, the two tensors are 2-D or n-D and matrix multiplication
performs in the following way.
- If both are 2-D, they are multiplied like conventional matrices.
- If either is n-D, it is treated as a stack of matrices residing in the
last two dimensions and a batched matrix multiply supporting broadcast
applies on the two tensors.
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
nontransposed, the prepended or appended dimension :math:`1` will be
removed after matrix multiplication.
are transposed. For a 1-D tensor the transpose flag has no effect. If the
tensor is 1-D of shape :math:`[D]`, then for :math:`x` it is treated as
:math:`[1, D]`, whereas for :math:`y` it is the opposite: it is treated as
:math:`[D, 1]`.
The multiplication behavior depends on the dimensions of `x` and `y`. Specifically:
- If both tensors are 1-dimensional, the dot product result is obtained.
- If both tensors are 2-dimensional, the matrix-matrix product is obtained.
- If the `x` is 1-dimensional and the `y` is 2-dimensional,
a `1` is prepended to its dimension in order to conduct the matrix multiply.
After the matrix multiply, the prepended dimension is removed.
- If the `x` is 2-dimensional and `y` is 1-dimensional,
the matrix-vector product is obtained.
- If both arguments are at least 1-dimensional and at least one argument
is N-dimensional (where N > 2), then a batched matrix multiply is obtained.
If the first argument is 1-dimensional, a 1 is prepended to its dimension
in order to conduct the batched matrix multiply and removed after.
If the second argument is 1-dimensional, a 1 is appended to its
dimension for the purpose of the batched matrix multiply and removed after.
The non-matrix dimensions (all but the last two) are
broadcast according to the standard broadcast rules.
For example, if input is a (j, 1, n, m) tensor and the other is a (k, m, p) tensor,
out will be a (j, k, n, p) tensor.
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a Tensor or LoDTensor.
x (Tensor): The input tensor which is a Tensor.
y (Tensor): The input tensor which is a Tensor.
transpose_x (bool): Whether to transpose :math:`x` before multiplication.
transpose_y (bool): Whether to transpose :math:`y` before multiplication.
alpha (float): The scale of output. Default 1.0.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The product Tensor (or LoDTensor) variable.
Tensor: The output Tensor.
Examples:
.. code-block:: python
# Examples to clarify shapes of the inputs and output
# x: [B, ..., M, K], y: [B, ..., K, N]
# paddle.matmul(x, y) # out: [B, ..., M, N]
.. code-block:: python
# x: [B, M, K], y: [B, K, N]
# paddle.matmul(x, y) # out: [B, M, N]
import paddle
import numpy as np
# x: [B, M, K], y: [K, N]
# paddle.matmul(x, y) # out: [B, M, N]
paddle.disable_static()
# vector * vector
x_data = np.random.random([10]).astype(np.float32)
y_data = np.random.random([10]).astype(np.float32)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
z = paddle.matmul(x, y)
print(z.numpy().shape)
# [1]
# x: [M, K], y: [K, N]
# paddle.matmul(x, y) # out: [M, N]
# matrix * vector
x_data = np.random.random([10, 5]).astype(np.float32)
y_data = np.random.random([5]).astype(np.float32)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
z = paddle.matmul(x, y)
print(z.numpy().shape)
# [10]
# x: [B, M, K], y: [K]
# paddle.matmul(x, y) # out: [B, M]
# batched matrix * broadcasted vector
x_data = np.random.random([10, 5, 2]).astype(np.float32)
y_data = np.random.random([2]).astype(np.float32)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
z = paddle.matmul(x, y)
print(z.numpy().shape)
# [10, 5]
# x: [K], y: [K]
# paddle.matmul(x, y) # out: [1]
# batched matrix * batched matrix
x_data = np.random.random([10, 5, 2]).astype(np.float32)
y_data = np.random.random([10, 2, 5]).astype(np.float32)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
z = paddle.matmul(x, y)
print(z.numpy().shape)
# [10, 5, 5]
# x: [M], y: [N]
# paddle.matmul(x, y, True, True) # out: [M, N]
# batched matrix * broadcasted matrix
x_data = np.random.random([10, 1, 5, 2]).astype(np.float32)
y_data = np.random.random([1, 3, 2, 5]).astype(np.float32)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
z = paddle.matmul(x, y)
print(z.numpy().shape)
# [10, 3, 5, 5]
import paddle
import paddle.fluid as fluid
x = fluid.data(name='x', shape=[2, 3], dtype='float32')
y = fluid.data(name='y', shape=[3, 2], dtype='float32')
out = paddle.matmul(x, y, True, True)
"""
op_type = 'matmul_v2'
if in_dygraph_mode():
op = getattr(core.ops, op_type)
return op(x, y, 'trans_x', transpose_x, 'trans_y', transpose_y)
attrs = {
'transpose_X': transpose_x,
'transpose_Y': transpose_y,
'alpha': float(alpha),
'trans_x': transpose_x,
'trans_y': transpose_y,
}
if in_dygraph_mode():
out = _varbase_creator(dtype=x.dtype)
core.ops.matmul(x, y, out, 'transpose_X', transpose_x, 'transpose_Y',
transpose_y, 'alpha', float(alpha))
return out
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'matmul')
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]
# check the inner 2 dimensions
if transpose_x:
x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2]
if transpose_y:
y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
if x_shape[-1] != y_shape[-2]:
assert (x_shape[-1] == -1) or (y_shape[-2] == -1), \
"After performing an optional transpose, Input X's width should be " \
"equal to Y's width for multiplication " \
"prerequisites. But received X's shape: %s, Y's shape: %s\n" % \
(x_shape, y_shape)
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape))
check_variable_and_dtype(val, name, ['float32', 'float64'],
'matmul')
__check_input(x, y)
helper = LayerHelper('matmul', **locals())
helper = LayerHelper('matmul_v2', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='matmul',
type='matmul_v2',
inputs={'X': x,
'Y': y},
outputs={'Out': out},
......