未验证 提交 78a3d837 编写于 作者: A Aurelius84 提交者: GitHub

Add match_matrix_tensor op (#18525)

* add match_matrix_tensor op test=develop

* fix ignore unittest if with_mkl=off test=develop

* clean code and rm is_test param test=develop

* modify API.spec test=develop

* rm useless code in search_compute.h test=develop

* modify api.spec test=develop

* modify default_grad.spec test=develop

* Add API test code test=develop

* clean code in search_compute.h

* modify PADDLE_ENFORCE and clean search_compute.h test=develop

* fix code style test=develop
上级 5b6673c4
......@@ -283,6 +283,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
......
......@@ -8,6 +8,7 @@ fused_embedding_seq_pool
gru
lrn
lstm_unit
match_matrix_tensor
max_pool2d_with_index
max_pool3d_with_index
maxout
......
......@@ -50,6 +50,7 @@ endif()
# OP_ONLY_MKL starts empty and is only filled when WITH_MKL is off; it then
# names the operators added below.
# NOTE(review): presumably this list is consumed further down to leave these
# operators out of a non-MKL build -- confirm against the rest of the file.
SET(OP_ONLY_MKL "")
if (NOT WITH_MKL)
SET(OP_ONLY_MKL ${OP_ONLY_MKL} match_matrix_tensor_op)
SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op)
endif()
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <vector>
#include "paddle/fluid/operators/match_matrix_tensor_op.h"
#include "paddle/fluid/operators/search_compute.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
// Validates the inputs of match_matrix_tensor and sets the shapes of the
// outputs "Out" and "Tmp".
//
// Checks performed:
//   * X, Y, W, Out and Tmp are all present;
//   * X and Y are 2-D, W is 3-D with W[0] == X[1], W[1] == dim_t and
//     W[2] == Y[1];
//   * X[1] == Y[1], because the CPU kernel addresses both X*W and Y with the
//     same row width;
//   * at run time, X and Y carry a level-0 LoD that is consistent with their
//     first dimension and that describes the same number of sequences.
//
// Resulting shapes:
//   * run time:     Out = [dim_t * sum_i(x_len_i * y_len_i), 1]
//                   Tmp = [X[0] * dim_t * X[1], 1]
//   * compile time: both first dimensions are left at -1.
void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                    "X(Input) of MatchMatrix should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
                    "Y(Input) of MatchMatrix should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
                    "W(Input) of MatchMatrix should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                    "Out(Output) of MatchMatrix should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasOutput("Tmp"), true,
                    "Tmp(Output) of MatchMatrix should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of Input(X) must be 2.");

  auto y_dims = ctx->GetInputDim("Y");
  PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The rank of Input(Y) must be 2.");

  auto w_dims = ctx->GetInputDim("W");
  PADDLE_ENFORCE_EQ(w_dims.size(), 3, "W should be 3-D tensor");

  int dim_t = ctx->Attrs().Get<int>("dim_t");
  PADDLE_ENFORCE_EQ(w_dims[0], x_dims[1],
                    "W 's shape must satisfy: W[0] = X[1]");
  PADDLE_ENFORCE_EQ(w_dims[1], dim_t, "W 's shape must satisfy: W[1] = dim_t");
  PADDLE_ENFORCE_EQ(w_dims[2], y_dims[1],
                    "W 's shape must satisfy: W[2] = Y[1]");
  // The kernel multiplies X by W viewed as [X[1], dim_t * X[1]] and then walks
  // Y with a row stride of X[1]; a different Y[1] would read out of bounds.
  PADDLE_ENFORCE_EQ(x_dims[1], y_dims[1],
                    "The hidden size of X and Y must be equal: X[1] = Y[1]");

  // 64-bit accumulators: the products below can exceed the range of int.
  int64_t out_dim_0 = -1;
  int64_t tmp_dim_0 = -1;
  if (ctx->IsRuntime()) {
    framework::Variable* x_var =
        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
    const auto& x_lod = x_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE_EQ(x_lod.empty(), false, "The Input(X) must hold lod info.");
    const auto& x_lod_0 = x_lod[0];
    PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
                      "The Input(X)'s lod info is corrupted.");
    PADDLE_ENFORCE_EQ(
        x_dims[0], static_cast<int64_t>(x_lod_0.back()),
        "The Input(X)'s lod info mismatches the actual tensor shape.");

    framework::Variable* y_var =
        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
    const auto& y_lod = y_var->Get<LoDTensor>().lod();
    PADDLE_ENFORCE_EQ(y_lod.empty(), false, "The Input(Y) must hold lod info.");
    const auto& y_lod_0 = y_lod[0];
    PADDLE_ENFORCE_GE(y_lod_0.size(), 2,
                      "The Input(Y)'s lod info is corrupted.");
    PADDLE_ENFORCE_EQ(
        y_dims[0], static_cast<int64_t>(y_lod_0.back()),
        "The Input(Y)'s lod info mismatches the actual tensor shape.");

    PADDLE_ENFORCE_EQ(x_lod_0.size(), y_lod_0.size(),
                      "The Length of X and Y must be equal.");

    // Every sequence pair contributes an x_len-by-y_len block per channel.
    out_dim_0 = 0;
    for (size_t i = 1; i < x_lod_0.size(); i++) {
      const int64_t x_len = static_cast<int64_t>(x_lod_0[i] - x_lod_0[i - 1]);
      const int64_t y_len = static_cast<int64_t>(y_lod_0[i] - y_lod_0[i - 1]);
      out_dim_0 += (x_len * y_len);
    }
    out_dim_0 *= dim_t;

    tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
  } else {
    // compile time: only the LoD level can be checked, sizes stay unknown
    framework::VarDesc* x_desc =
        boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
    PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1,
                      "The Input(X) must be a LoDTensor with lod_level >= 1.");
    framework::VarDesc* y_desc =
        boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Y")[0]);
    PADDLE_ENFORCE_GE(y_desc->GetLoDLevel(), 1,
                      "The Input(Y) must be a LoDTensor with lod_level >= 1.");
  }

  std::vector<int64_t> out_dims_vec{out_dim_0};
  out_dims_vec.push_back(1);
  std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
  tmp_dims_vec.push_back(1);
  ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
  ctx->SetOutputDim("Tmp", framework::make_ddim(tmp_dims_vec));
}
// Shape inference for the gradient operator.
//
// Requires X, Y, W and Out@GRAD to be present. Each gradient output that
// exists takes the shape of its forward input; X@GRAD and Y@GRAD additionally
// share the LoD of X and Y.
void MatchMatrixTensorOpGrad::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                    "Input(X) of MatchMatrixTensorOpGrad should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
                    "Input(Y) of MatchMatrixTensorOpGrad should not be null.");
  PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
                    "Input(W) of MatchMatrixTensorOpGrad should not be null.");
  PADDLE_ENFORCE_EQ(
      ctx->HasInput(framework::GradVarName("Out")), true,
      "Input(Out@GRAD) of MatchMatrixTensorOpGrad should not be null.");

  if (ctx->HasOutput(framework::GradVarName("X"))) {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
  }
  if (ctx->HasOutput(framework::GradVarName("Y"))) {
    ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y"));
    ctx->ShareLoD("Y", /*->*/ framework::GradVarName("Y"));
  }
  // W is a plain Tensor, so there is no LoD to share.
  if (ctx->HasOutput(framework::GradVarName("W"))) {
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
  }
}
// Declares the operator's interface: inputs X, Y and W, the integer attribute
// dim_t (default 1), outputs Out and Tmp, and the operator documentation.
void MatchMatrixTensorOpMaker::Make() {
// The two sequence inputs; InferShape requires both to carry LoD.
AddInput("X",
"X (LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information.");
AddInput("Y",
"Y (LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information.");
// The weight; InferShape requires it to be 3-D with W[1] == dim_t.
AddInput("W", "W (Tensor), The weight of X and Y.");
AddAttr<int>("dim_t", "the dim of W").SetDefault(1);
AddOutput("Out",
"(LoDTensor, default LoDTensor<float>) Output variable which "
"is X * W * Y");
// Tmp holds the intermediate product X * W; the gradient kernel reads it
// back as an input.
AddOutput("Tmp",
"(LoDTensor, default LoDTensor<float>) tmp variable which is "
"used for X * W");
AddComment(R"DOC(
Match Matrix Tensor Operator
This operator calculate X * W * Y, only support 2-D for X and Y.
the output is a level-1 LodTensor:
level_0: dim_t
NOTE: only support 'float32' data type now.
)DOC");
}
// CPU forward kernel of match_matrix_tensor.
//
// Step 1: Tmp = X * W, with W viewed as a [dim_in, dim_t * dim_in] matrix, so
//         each row of Tmp holds dim_t consecutive slices of width dim_in.
// Step 2: for every sequence pair b and every channel t, the
//         [len_l, len_r] block  Tmp_b[:, t] * Y_b^T  is written to Out.
// Out receives a level-0 LoD whose offsets delimit the per-pair blocks.
template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* y = ctx.Input<LoDTensor>("Y");
    auto* w = ctx.Input<Tensor>("W");
    auto* out = ctx.Output<LoDTensor>("Out");
    auto* tmp = ctx.Output<LoDTensor>("Tmp");

    const int dim_t = ctx.Attr<int>("dim_t");
    const int dim_in = x->dims()[1];

    const auto& x_offsets = x->lod()[0];
    const auto& y_offsets = y->lod()[0];
    const size_t num_pairs = x_offsets.size() - 1;

    // Running start offset of every pair's block inside Out.
    std::vector<size_t> out_offsets;
    int running_size = 0;
    out_offsets.push_back(running_size);
    for (size_t pair = 0; pair < num_pairs; ++pair) {
      const int len_l = x_offsets[pair + 1] - x_offsets[pair];
      const int len_r = y_offsets[pair + 1] - y_offsets[pair];
      running_size += dim_t * len_l * len_r;
      out_offsets.push_back(running_size);
    }

    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    memset(out_data, 0, out->dims()[0] * out->dims()[1] * sizeof(T));

    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    const T* w_data = w->data<T>();

    T* xw_data = tmp->mutable_data<T>(ctx.GetPlace());
    memset(xw_data, 0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);

    // Tmp = X * W for all rows of X at once.
    call_gemm(blas, CblasNoTrans, CblasNoTrans, x->dims()[0], dim_t * dim_in,
              dim_in, 1.0f, x_data, w_data, 0.0f, xw_data);

    for (size_t pair = 0; pair < num_pairs; ++pair) {
      const int len_l = x_offsets[pair + 1] - x_offsets[pair];
      const int len_r = y_offsets[pair + 1] - y_offsets[pair];
      const T* y_rows = y_data + y_offsets[pair] * dim_in;
      for (int channel = 0; channel < dim_t; ++channel) {
        T* block = out_data + out_offsets[pair] + channel * len_l * len_r;
        // Slice `channel` of the pair's rows in Tmp; consecutive rows are
        // dim_t * dim_in elements apart, hence the explicit lda.
        const T* xw_rows =
            xw_data + x_offsets[pair] * dim_t * dim_in + channel * dim_in;
        call_gemm_with_lda(blas, CblasNoTrans, CblasTrans, len_l, len_r,
                           dim_in, 1.0f, xw_rows, y_rows, 0.0f, block,
                           dim_t * dim_in);
      }
    }

    framework::LoD out_lod;
    out_lod.push_back(out_offsets);
    out->set_lod(out_lod);
  }
};
// CPU backward kernel of match_matrix_tensor.
//
// Inputs:  X, Y, W, Tmp (= X * W from the forward pass) and Out@GRAD.
// Outputs: X@GRAD, Y@GRAD and W@GRAD. Each one is optional, matching
//          MatchMatrixTensorOpGrad::InferShape; a gradient that was not
//          requested is skipped instead of being dereferenced.
//
// Computation:
//   1. For every non-zero element of Out@GRAD, accumulate
//        Tmp@GRAD[row i, channel t] += diff * Y[row j]
//        Y@GRAD[row j]              += diff * Tmp[row i, channel t]
//   2. X@GRAD += Tmp@GRAD * W^T
//   3. W@GRAD += X^T * Tmp@GRAD
template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* y = ctx.Input<LoDTensor>("Y");
    auto* w = ctx.Input<Tensor>("W");
    auto* tmp = ctx.Input<LoDTensor>("Tmp");

    int dim_t = ctx.Attr<int>("dim_t");
    int dim_in = x->dims()[1];

    const auto& offset_l = x->lod()[0];
    const auto& offset_r = y->lod()[0];

    // Start offset of every sequence pair's block inside Out@GRAD.
    std::vector<int> top_offset;
    int top_size = 0;
    top_offset.push_back(top_size);
    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      int len_l = offset_l[b + 1] - offset_l[b];
      int len_r = offset_r[b + 1] - offset_r[b];
      top_size += dim_t * len_l * len_r;
      top_offset.push_back(top_size);
    }

    auto* bottom_l_data = x->data<T>();
    auto* bottom_r_data = y->data<T>();
    auto* bottom_l_trans_data = tmp->data<T>();

    auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto* d_y = ctx.Output<LoDTensor>(framework::GradVarName("Y"));
    auto* d_w = ctx.Output<Tensor>(framework::GradVarName("W"));

    // Scratch gradient of Tmp; it only lives for the duration of this call.
    Tensor tmp_grad;
    tmp_grad.Resize(tmp->dims());
    T* bottom_l_trans_diff = tmp_grad.mutable_data<T>(ctx.GetPlace());
    memset(bottom_l_trans_diff, 0,
           tmp->dims()[0] * tmp->dims()[1] * sizeof(T));

    auto* top_diff = d_out->data<T>();

    // Gradient buffers stay null when the corresponding output is absent.
    T* bottom_l_diff = nullptr;
    if (d_x != nullptr) {
      bottom_l_diff = d_x->mutable_data<T>(ctx.GetPlace());
      memset(bottom_l_diff, 0, x->dims()[0] * x->dims()[1] * sizeof(T));
    }
    T* bottom_r_diff = nullptr;
    if (d_y != nullptr) {
      bottom_r_diff = d_y->mutable_data<T>(ctx.GetPlace());
      memset(bottom_r_diff, 0, y->dims()[0] * y->dims()[1] * sizeof(T));
    }

    for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
      for (int t = 0; t < dim_t; t++) {
        int len_l = offset_l[b + 1] - offset_l[b];
        int len_r = offset_r[b + 1] - offset_r[b];

        for (int i = 0; i < len_l; i++) {
          for (int j = 0; j < len_r; j++) {
            auto diff =
                top_diff[top_offset[b] + t * len_l * len_r + i * len_r + j];
            // Zero gradients contribute nothing; skip the two axpy calls.
            if (diff == 0.0) {
              continue;
            }

            auto* l_trans_diff = bottom_l_trans_diff +
                                 (offset_l[b] + i) * dim_in * dim_t +
                                 t * dim_in;
            auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
            sse_axpy(r_data, l_trans_diff, dim_in, diff);

            if (bottom_r_diff != nullptr) {
              auto* l_trans_data = bottom_l_trans_data +
                                   (offset_l[b] + i) * dim_in * dim_t +
                                   t * dim_in;
              auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
              sse_axpy(l_trans_data, r_diff, dim_in, diff);
            }
          }
        }
      }
    }

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);

    // bottom_diff: X@GRAD += Tmp@GRAD * W^T
    if (bottom_l_diff != nullptr) {
      auto* t_data = w->data<T>();
      call_gemm(blas, CblasNoTrans, CblasTrans, x->dims()[0], dim_in,
                dim_t * dim_in, 1.0f, bottom_l_trans_diff, t_data, 1.0f,
                bottom_l_diff);
    }

    // t_diff: W@GRAD += X^T * Tmp@GRAD
    if (d_w != nullptr) {
      auto* t_diff = d_w->mutable_data<T>(ctx.GetPlace());
      memset(t_diff, 0,
             w->dims()[0] * w->dims()[1] * w->dims()[2] * sizeof(T));
      call_gemm(blas, CblasTrans, CblasNoTrans, dim_in, dim_t * dim_in,
                x->dims()[0], 1.0f, bottom_l_data, bottom_l_trans_diff, 1.0f,
                t_diff);
    }
  }
};
} // namespace operators
} // namespace paddle
// Operator registration: the forward operator with its proto maker and the
// default gradient-op description maker, followed by the gradient operator.
namespace ops = paddle::operators;
REGISTER_OPERATOR(match_matrix_tensor, ops::MatchMatrixTensorOP,
ops::MatchMatrixTensorOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad);
// Only float CPU kernels are registered. The double variants are kept below
// as commented-out code; the operator documentation states that only
// 'float32' is supported.
REGISTER_OP_CPU_KERNEL(match_matrix_tensor,
ops::CPUMatchMatrixTensorOPKernel<
paddle::platform::CPUDeviceContext, float>);
// ops::CPUMatchMatrixTensorOPKernel<paddle::platform::CPUDeviceContext,
// double>
REGISTER_OP_CPU_KERNEL(match_matrix_tensor_grad,
ops::CPUMatchMatrixTensorOPGradKernel<
paddle::platform::CPUDeviceContext, float>);
// ops::CPUMatchMatrixTensorOPGradKernel<paddle::platform::CPUDeviceContext,
// double>
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward operator of match_matrix_tensor. It inherits the constructors of
// OperatorWithKernel and only overrides shape inference.
class MatchMatrixTensorOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// Validates the inputs and sets the shapes of the outputs.
void InferShape(framework::InferShapeContext* ctx) const override;
};
// Gradient operator of match_matrix_tensor. It inherits the constructors of
// OperatorWithKernel and only overrides shape inference.
class MatchMatrixTensorOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// Sets the shapes of the gradient outputs.
void InferShape(framework::InferShapeContext* ctx) const override;
};
// Proto maker of match_matrix_tensor.
class MatchMatrixTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Declares the operator's inputs, outputs, attributes and documentation.
void Make() override;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <immintrin.h>
#include <cfloat>
#include <cmath>
#include <cstring>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
// GEMM wrapper: C = alpha * op(A) * op(B) + beta * C, where op(A) is M x K and
// op(B) is K x N.
//
// The leading dimensions are derived for densely packed operands: each one is
// the column count of the matrix as stored (before op() is applied), and C is
// stored with N columns.
template <typename DeviceContext, typename T>
void call_gemm(const math::BlasT<DeviceContext, T>& blas,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  int lda = M;
  if (TransA == CblasNoTrans) {
    lda = K;
  }
  int ldb = K;
  if (TransB == CblasNoTrans) {
    ldb = N;
  }
  const int ldc = N;
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
// Same contract as the BlasT overload of call_gemm, except that the CPU BLAS
// handle is obtained from the execution context on every call.
template <typename T>
void call_gemm(const framework::ExecutionContext& ctx,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  auto cpu_blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
  // Leading dimension = column count of the operand as it is stored.
  int lda = M;
  if (TransA == CblasNoTrans) {
    lda = K;
  }
  int ldb = K;
  if (TransB == CblasNoTrans) {
    ldb = N;
  }
  const int ldc = N;
  cpu_blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
// GEMM wrapper for an A operand whose rows are not contiguous: the caller
// passes the leading dimension of A explicitly. ldb and ldc are still derived
// as for densely packed B and C.
template <typename DeviceContext, typename T>
void call_gemm_with_lda(const math::BlasT<DeviceContext, T>& blas,
                        const CBLAS_TRANSPOSE TransA,
                        const CBLAS_TRANSPOSE TransB, const int M, const int N,
                        const int K, const T alpha, const T* A, const T* B,
                        const T beta, T* C, int lda) {
  int ldb = K;
  if (TransB == CblasNoTrans) {
    ldb = N;
  }
  const int ldc = N;
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
// Runs `batch` independent GEMMs one after another; entry i of A, B and C
// supplies the operands of the i-th product. All products share M, N, K,
// alpha, beta and the transpose flags.
template <typename T>
void call_gemm_batched(const framework::ExecutionContext& ctx,
                       const CBLAS_TRANSPOSE TransA,
                       const CBLAS_TRANSPOSE TransB, const int M, const int N,
                       const int K, const T alpha, const T** A, const T** B,
                       const T beta, T** C, const int batch) {
  int item = 0;
  while (item < batch) {
    call_gemm(ctx, TransA, TransB, M, N, K, alpha, A[item], B[item], beta,
              C[item]);
    ++item;
  }
}
// SIMD configuration for the helpers below.
// TYPE_USE_FLOAT and USE_SSE are defined here whenever they are not defined
// already, so the single-precision aliases are always active in this header.
#ifndef TYPE_USE_FLOAT
#define TYPE_USE_FLOAT
#endif
#ifndef USE_SSE
#define USE_SSE
#endif
// Single-precision aliases: the *_px / *_sx names map to the *_ps / *_ss
// intrinsics. Loads and stores use the unaligned (loadu/storeu) variants.
#if defined(TYPE_USE_FLOAT)
#define __m256x __m256
#define __m128x __m128
// Floats per AVX / SSE vector, and the masks that extract the remainder of a
// length that is not a multiple of the vector width.
static const unsigned int AVX_STEP_SIZE = 8;
static const unsigned int SSE_STEP_SIZE = 4;
static const unsigned int AVX_CUT_LEN_MASK = 7U;
static const unsigned int SSE_CUT_LEN_MASK = 3U;
#define _mm256_mul_px _mm256_mul_ps
#define _mm256_add_px _mm256_add_ps
#define _mm256_load_px _mm256_loadu_ps
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss
#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps
#endif
// In-place AXPY: y[i] += alpha * x[i] for i in [0, len).
//
// x     : input vector with at least `len` elements.
// y     : vector updated in place, at least `len` elements.
// len   : number of elements to process; 0 is a no-op.
// alpha : scale applied to x.
//
// With USE_AVX (or otherwise USE_SSE) defined, the largest prefix whose length
// is a multiple of the vector width goes through the SIMD path and the scalar
// loop finishes the remainder. The SIMD aliases are single precision, so those
// paths only compile for T = float.
//
// All indices are size_t: with 32-bit counters the vectorised bound would be
// truncated and the tail loop counter could wrap for lengths beyond the
// 32-bit range.
template <typename T>
inline void sse_axpy(const T* x, T* y, size_t len, const T alpha) {
  size_t idx = 0;
#if defined(USE_AVX)
  const size_t simd_len = len & ~static_cast<size_t>(AVX_CUT_LEN_MASK);
  const __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
  for (; idx < simd_len; idx += AVX_STEP_SIZE) {
    _mm256_store_px(
        y + idx,
        _mm256_add_px(_mm256_load_px(y + idx),
                      _mm256_mul_px(mm_alpha, _mm256_load_px(x + idx))));
  }
#elif defined(USE_SSE)
  const size_t simd_len = len & ~static_cast<size_t>(SSE_CUT_LEN_MASK);
  const __m128x mm_alpha = _mm_load1_px(&alpha);
  for (; idx < simd_len; idx += SSE_STEP_SIZE) {
    _mm_store_px(y + idx,
                 _mm_add_px(_mm_load_px(y + idx),
                            _mm_mul_px(mm_alpha, _mm_load_px(x + idx))));
  }
#endif
  // Scalar tail (the whole range when no SIMD path is compiled in).
  for (; idx < len; ++idx) {
    y[idx] += alpha * x[idx];
  }
}
} // namespace operators
} // namespace paddle
......@@ -211,6 +211,7 @@ __all__ = [
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'match_matrix_tensor',
'filter_by_instag',
'var_conv_2d',
'shard_index',
......@@ -13092,6 +13093,88 @@ def var_conv_2d(input,
return helper.append_activation(conv_res)
def match_matrix_tensor(x,
                        y,
                        channel_num,
                        act=None,
                        param_attr=None,
                        dtype='float32',
                        name=None):
    """
    Compute the semantic matching matrix of two variable-length word sequences.

    For a query A of length `n` and a title B of length `m`, the inputs have
    shapes [n, h] and [m, h], where h is the hidden size. With
    :attr:`channel_num` set to 3, a learnable parameter W of shape [h, 3, h]
    is created and the matching matrix is

        A * W * B.T = [n, h] * [h, 3, h] * [h, m] = [n, 3, m].

    W plays the role of a fully connected layer in this computation. When
    :attr:`act` is given, that activation is applied to the output matrix.

    Both :attr:`x` and :attr:`y` must be LoDTensors with exactly one LoD level.

    .. code-block:: text

            Given a 1-level LoDTensor x:
                x.lod =  [[2,                     3,                               ]]
                x.data = [[0.3, 0.1], [0.2, 0.3], [0.5, 0.6], [0.7, 0.1], [0.3, 0.4]]
                x.dims = [5, 2]
            and a 1-level LoDTensor y:
                y.lod =  [[3,                                 1,       ]]
                y.data = [[0.1, 0.2], [0.3, 0.7], [0.9, 0.2], [0.4, 0.1]]
                y.dims = [4, 2]
            with channel_num set to 2 the output is a 1-level LoDTensor:
                out.lod =  [[12, 6]]   # where 12 = channel_num * x.lod[0][0] * y.lod[0][0]
                out.dims = [18, 1]     # where 18 = 12 + 6

    Args:
        x (Variable): Input variable x, a 1-level LoDTensor.
        y (Variable): Input variable y, a 1-level LoDTensor.
        channel_num (int): Number of channels of the learnable parameter W.
        act (str, default None): Activation applied to the output of this layer.
        param_attr (ParamAttr|list of ParamAttr, default None): Attribute of the
            learnable parameter/weight of this layer.
        dtype ('float32'): Data type of W.
        name (str|None): Optional name for this layer. If None, the layer is
            named automatically. Default: None

    Returns:
        tuple: ``(out, tmp)`` where ``out`` is the (optionally activated)
        output Variable with the LoD produced by this layer, and ``tmp`` is
        the intermediate Variable holding x * W.

    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.fluid import layers

            x_lod_tensor = layers.data(name='x', shape=[10], lod_level=1)
            y_lod_tensor = layers.data(name='y', shape=[10], lod_level=1)
            out, out_tmp = layers.match_matrix_tensor(x=x_lod_tensor, y=y_lod_tensor, channel_num=3)
    """
    # Must stay the first statement: locals() has to contain only the
    # arguments of this function.
    helper = LayerHelper('match_matrix_tensor', **locals())

    x_dims = list(x.shape)
    y_dims = list(y.shape)
    # Both inputs are 2-D and share the hidden size.
    assert (len(x_dims) == 2 and len(y_dims) == 2 and
            x_dims[-1] == y_dims[-1])

    w = helper.create_parameter(
        attr=helper.param_attr,
        shape=[x_dims[-1], channel_num, y_dims[-1]],
        dtype=dtype,
        is_bias=False)

    mm_res = helper.create_variable_for_type_inference(dtype)
    # The intermediate x * W is an operator output that needs no gradient.
    tmp_res = helper.create_variable_for_type_inference(
        dtype, stop_gradient=True)

    op_inputs = {'X': x, 'Y': y, 'W': w}
    op_outputs = {"Out": mm_res, "Tmp": tmp_res}
    helper.append_op(
        type='match_matrix_tensor',
        inputs=op_inputs,
        outputs=op_outputs,
        attrs={'dim_t': channel_num})

    activated = helper.append_activation(mm_res)
    return activated, tmp_res
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
This layer creates the sharded index for input. This layers is used in
......
......@@ -78,10 +78,10 @@ if(NOT WITH_MKLML)
endif()
# Without MKL, remove the unit tests of match_matrix_tensor and var_conv_2d
# from TEST_OPS.
if(NOT WITH_MKL)
list(REMOVE_ITEM TEST_OPS test_match_matrix_tensor_op)
list(REMOVE_ITEM TEST_OPS test_var_conv_2d)
endif(NOT WITH_MKL)
if(WITH_GPU OR NOT WITH_MKLML)
# matmul with multiple heads need MKL support
LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
class TestMatchMatrixTensorOp(OpTest):
    """Checks match_matrix_tensor against a NumPy reference.

    Subclasses override ``set_data`` to vary the sequence layout, the hidden
    size and the number of channels.
    """

    def setUp(self):
        self.init_op_type()
        self.set_data()
        self.compute()

    def init_op_type(self):
        self.op_type = "match_matrix_tensor"

    def set_data(self):
        # 5 rows of X and 8 rows of Y split into three sequence pairs,
        # hidden size 3, 4 channels.
        x_lod = [[1, 2, 2]]
        y_lod = [[3, 1, 4]]
        self.init_data(5, x_lod, 8, y_lod, 3, 4)

    def init_data(self, ix, x_lod, iy, y_lod, h, dim_t):
        """Create random float32 inputs X [ix, h], Y [iy, h], W [h, dim_t, h]."""
        x_data = np.random.random((ix, h)).astype('float32')
        y_data = np.random.random((iy, h)).astype('float32')
        w_data = np.random.random((h, dim_t, h)).astype('float32')
        self.inputs = {
            'X': (x_data, x_lod),
            'Y': (y_data, y_lod),
            'W': w_data,
        }
        self.attrs = {'dim_t': dim_t}

    def compute(self):
        """Build the expected 'Out' (with its LoD) and 'Tmp' outputs."""
        x_data, x_lod = self.inputs['X']
        y_data, y_lod = self.inputs['Y']
        # [k, dim_t, k] -> [dim_t, k, k]
        w_data = self.inputs['W'].transpose(1, 0, 2)

        empty_column = np.zeros((0, 1), dtype=x_data.dtype)
        out_pieces = [empty_column]
        # pieces of x * w
        tmp_pieces = [empty_column]
        out_lens = []

        x_begin, y_begin = 0, 0
        for idx, x_len in enumerate(x_lod[0]):
            y_len = y_lod[0][idx]
            x_seq = x_data[x_begin:(x_begin + x_len), :]
            y_seq = y_data[y_begin:(y_begin + y_len), :]

            # [x_len, k] . [dim_t, k, k] -> [x_len, dim_t, k]
            xw = np.dot(x_seq, w_data)
            tmp_pieces.append(xw.reshape(xw.size, 1))

            # [x_len, dim_t, y_len] -> [dim_t, x_len, y_len]
            scores = np.dot(xw, y_seq.T).transpose(1, 0, 2)
            out_lens.append(scores.size)
            out_pieces.append(scores.reshape(scores.size, 1))

            x_begin += x_len
            y_begin += y_len

        self.outputs = {
            'Out': (np.vstack(out_pieces), [out_lens]),
            'Tmp': np.vstack(tmp_pieces),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)
class TestMatchMatrixTensorOpCase1(TestMatchMatrixTensorOp):
    """A single sequence pair (5 x 8 rows), hidden size 16, 4 channels."""

    def set_data(self):
        self.init_data(
            ix=5, x_lod=[[5]], iy=8, y_lod=[[8]], h=16, dim_t=4)
class TestMatchMatrixTensorOpCase2(TestMatchMatrixTensorOp):
    """Three sequence pairs with hidden size 1 and 4 channels."""

    def set_data(self):
        self.init_data(
            ix=7,
            x_lod=[[2, 3, 2]],
            iy=8,
            y_lod=[[3, 1, 4]],
            h=1,
            dim_t=4)
class TestMatchMatrixTensorOpCase3(TestMatchMatrixTensorOp):
    """Three sequence pairs with hidden size 32 and a single channel."""

    def set_data(self):
        self.init_data(
            ix=5,
            x_lod=[[1, 2, 2]],
            iy=9,
            y_lod=[[3, 2, 4]],
            h=32,
            dim_t=1)
class TestMatchMatrixTensorOpCase4(TestMatchMatrixTensorOp):
    """Five sequence pairs, hidden size 16, 5 channels, plus a layer API test."""

    def set_data(self):
        ix, iy, h, dim_t = [8, 12, 16, 5]
        x_lod = [[1, 2, 3, 1, 1]]
        y_lod = [[3, 2, 4, 1, 2]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)

    def test_api(self):
        """Run fluid.layers.match_matrix_tensor end to end and check the
        shape of the fetched output."""
        channel_num = 3
        hidden = 10
        x_lens = [2, 5]
        y_lens = [3, 6]

        x_lod_tensor = fluid.layers.data(
            name='x', shape=[hidden], lod_level=1)
        y_lod_tensor = fluid.layers.data(
            name='y', shape=[hidden], lod_level=1)
        out, out_tmp = fluid.layers.match_matrix_tensor(
            x=x_lod_tensor, y=y_lod_tensor, channel_num=channel_num)

        place = fluid.CPUPlace()
        x_data = np.random.rand(sum(x_lens), hidden).astype('float32')
        y_data = np.random.rand(sum(y_lens), hidden).astype('float32')
        x = fluid.create_lod_tensor(x_data, [x_lens], place)
        y = fluid.create_lod_tensor(y_data, [y_lens], place)

        exe = fluid.Executor(place=place)
        exe.run(fluid.default_startup_program())
        ret = exe.run(feed={'x': x,
                            'y': y},
                      fetch_list=[out],
                      return_numpy=False)

        # Out holds channel_num * x_len * y_len values per sequence pair,
        # stored as a single column.
        expected_rows = channel_num * sum(
            x_len * y_len for x_len, y_len in zip(x_lens, y_lens))
        self.assertEqual(np.array(ret[0]).shape, (expected_rows, 1))
# Run all test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册