Unverified commit c706ff20, authored by ShenLiang, committed by GitHub

fix conflict, test=develop (#23298)

Parent 5223e2bb
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
static inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
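// Layout of RankOffset assumed by the kernels below (inferred from the
// indexing here and from the numpy reference in test_rank_attention_op.py):
// each row describes one instance and has 2 * max_rank + 1 int32 columns:
//   col 0       : the instance's own rank (1-based), or -1 if it has none;
//   col 2k + 1  : the rank occupying slot k of the same pv (k + 1 when
//                 present, -1 otherwise);
//   col 2k + 2  : the row index in X of the instance occupying slot k.
// expand_input_by_rank_kernel gathers those rows of X into an
// [ins_num, max_rank * input_col] buffer (the caller zero-fills it, so empty
// slots stay 0) and records each instance's own rank in ins_rank.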
template <typename T>
__global__ void expand_input_by_rank_kernel(
const T* input, int input_row, int input_col, T* output, int output_row,
int output_col, const int* rank_offset, int rank_offset_row,
int rank_offset_col, T* ins_rank, int max_rank) {
CUDA_KERNEL_LOOP(idx, output_row * output_col) {
int output_col_idx = idx % output_col;
int output_row_idx = idx / output_col;
int k = output_col_idx / input_col;
int faster = rank_offset[output_row_idx * rank_offset_col + 2 * k + 1] - 1;
if (output_col_idx == 0) {
ins_rank[output_row_idx] = rank_offset[output_row_idx * rank_offset_col];
}
if (rank_offset[output_row_idx * rank_offset_col] - 1 < 0 || faster < 0) {
continue;
}
int rank_input_col_idx = output_col_idx % input_col;
int index = rank_offset[output_row_idx * rank_offset_col + 2 * k + 2];
output[idx] = input[rank_input_col_idx + index * input_col];
}
}
template <typename T>
void expand_rank_attention_input(cudaStream_t stream, const T* input,
int input_row, int input_col, T* output,
int output_row, int output_col,
const int* rank_offset, int rank_offset_row,
int rank_offset_col, T* ins_rank,
int max_rank) {
expand_input_by_rank_kernel<<<GET_BLOCKS(output_row * output_col),
CUDA_NUM_THREADS, 0, stream>>>(
input, input_row, input_col, output, output_row, output_col, rank_offset,
rank_offset_row, rank_offset_col, ins_rank, max_rank);
}
template <typename T>
__global__ void expand_rank_attention_param_kernel(
const T* input, int input_row, int input_col, const int* rank_offset,
int rank_offset_row, int rank_offset_col, const T* param, int param_row,
int param_col, T* output_param, int output_param_row, int output_param_col,
int max_rank) {
CUDA_KERNEL_LOOP(idx, output_param_row * output_param_col) {
int output_col_idx = idx % output_param_col;
int output_row_idx = idx / output_param_col;
int block_matrix_row = max_rank * input_col;
int ins_idx = output_row_idx / block_matrix_row;
int start_offset = output_row_idx % block_matrix_row;
int k = start_offset / input_col;
int k_offset = start_offset % input_col;
int lower = rank_offset[ins_idx * rank_offset_col] - 1;
int faster = rank_offset[2 * k + 1 + rank_offset_col * ins_idx] - 1;
if (lower < 0 || faster < 0) {
continue;
}
int start = lower * max_rank + faster;
int ori_idx =
start * param_col * input_col + k_offset * param_col + output_col_idx;
output_param[idx] = param[ori_idx];
}
}
template <typename T>
void expand_rank_attention_param(cudaStream_t stream, const T* input,
int input_row, int input_col,
const int* rank_offset, int rank_offset_row,
int rank_offset_col, const T* param,
int param_row, int param_col, T* output_param,
int output_param_row, int output_param_col,
int max_rank) {
expand_rank_attention_param_kernel<<<GET_BLOCKS(output_param_row *
output_param_col),
CUDA_NUM_THREADS, 0, stream>>>(
input, input_row, input_col, rank_offset, rank_offset_row,
rank_offset_col, param, param_row, param_col, output_param,
output_param_row, output_param_col, max_rank);
}
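// Reduces the per-instance expanded parameter gradient back into the dense
// RankParam gradient: expanded_grad is [ins_num * max_rank * input_col,
// para_col], param_grad is [max_rank * max_rank * input_col, para_col], and
// each destination row accumulates the corresponding rows of every instance
// whose own rank (ins_rank[i]) equals rank_idx + 1.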
template <typename T>
__global__ void merge_param_gradient_kernel(
T* expanded_grad, int expanded_grad_row, int expanded_grad_col,
T* param_grad, int param_grad_row, int param_grad_col, const T* ins_rank,
int ins_num, int max_rank, int input_col) {
CUDA_KERNEL_LOOP(tid, param_grad_row * param_grad_col) {
int param_col_idx = tid % param_grad_col;
int param_row_idx = tid / param_grad_col;
int block_matrix_row = max_rank * input_col;
int rank_idx = param_row_idx / block_matrix_row;
int rank_offset = param_row_idx % block_matrix_row;
T tmp = 0;
for (int i = 0; i < ins_num; ++i) {
if (ins_rank[i] == rank_idx + 1) {
int row = i * block_matrix_row + rank_offset;
tmp += expanded_grad[row * expanded_grad_col + param_col_idx];
}
}
param_grad[tid] = tmp;
}
}
template <typename T>
void merge_rank_attention_param_grad(cudaStream_t stream, T* expanded_grad,
int expanded_grad_row,
int expanded_grad_col, T* param_grad,
int param_grad_row, int param_grad_col,
const T* ins_rank, int ins_num,
int max_rank, int input_col) {
merge_param_gradient_kernel<<<GET_BLOCKS(param_grad_row * param_grad_col),
CUDA_NUM_THREADS, 0, stream>>>(
expanded_grad, expanded_grad_row, expanded_grad_col, param_grad,
param_grad_row, param_grad_col, ins_rank, ins_num, max_rank, input_col);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/rank_attention_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
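// Shape contract (as enforced by InferShape below and by the CUDA kernel):
//   X          : [ins_num, x_fea_dim]
//   RankOffset : [ins_num, 2 * MaxRank + 1], int32
//   RankParam  : [MaxRank * MaxRank * x_fea_dim, para_col]
//   Out        : [ins_num, para_col]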
class RankAttentionOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of RankAttentionOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("RankOffset"), true,
platform::errors::InvalidArgument(
"Input(RankOffset) of RankAttentionOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("RankParam"), true,
platform::errors::InvalidArgument(
"Input(RankParam) of RankAttentionOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of RankAttentionOp should not be null."));
auto max_rank = ctx->Attrs().Get<int>("MaxRank");
auto x_dims = ctx->GetInputDim("X");
auto ins_num = x_dims[0];
auto param_dims = ctx->GetInputDim("RankParam");
auto para_col = param_dims[1];
auto rank_offset_dims = ctx->GetInputDim("RankOffset");
PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
platform::errors::InvalidArgument(
"Input(RankOffset) has wrong columns."));
ctx->SetOutputDim("Out", {ins_num, para_col});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class RankAttentionGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null"));
PADDLE_ENFORCE_EQ(ctx->HasInput("RankParam"), true,
platform::errors::InvalidArgument(
"Input(RankParam) should not be null"));
PADDLE_ENFORCE_EQ(ctx->HasInput("RankOffset"), true,
platform::errors::InvalidArgument(
"Input(RankOffset) should not be null"));
ctx->SetOutputDim(framework::GradVarName("RankParam"),
ctx->GetInputDim("RankParam"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class RankAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of rank_attention_Op operator.");
AddInput("RankOffset",
"(Tensor) Input tensor of rank_attention_Op operator.");
AddInput("RankParam",
"(Tensor) Input tensor of rank_attention_Op operator.");
AddOutput("Out", "Output tensor of rank_attention_Op operator.");
AddAttr<int>("MaxRank", "(int, default 3) max rank of rank_attention_Op")
.SetDefault(3);
AddComment(R"DOC(
RankAttention Operator.
This Op can calculate rank attention between input and rank_param,
and rank_param gives the organization of data. Notice: It currently supports GPU device.
This Op exists in contrib, which means that it is not shown to the public.
)DOC");
}
};
template <typename T>
class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("rank_attention_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("RankOffset", this->Input("RankOffset"));
op->SetInput("RankParam", this->Input("RankParam"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("RankParam"),
this->InputGrad("RankParam"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
RankAttentionGradOpNoNeedBufferVarsInference, "RankParam");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(rank_attention, ops::RankAttentionOp,
ops::RankAttentionOpMaker,
ops::RankAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::RankAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rank_attention_grad, ops::RankAttentionGradOp,
ops::RankAttentionGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
rank_attention,
ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, float>,
ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cublas.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/rank_attention.cu.h"
#include "paddle/fluid/operators/rank_attention_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename DeviceContext, typename T>
class RankAttentionCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *rank_offset = ctx.Input<Tensor>("RankOffset");
auto *param = ctx.Input<Tensor>("RankParam");
int max_rank = ctx.Attr<int>("MaxRank");
auto *Out = ctx.Output<Tensor>("Out");
// check dims
auto x_dims = X->dims();
auto ins_num = x_dims[0];
auto x_fea_dim = x_dims[1];
auto para_dims = param->dims();
auto para_row = para_dims[0];
auto para_col = para_dims[1];
auto rank_offset_dims = rank_offset->dims();
PADDLE_ENFORCE_EQ(
rank_offset_dims[0], ins_num,
platform::errors::InvalidArgument("Input(RankOffset) has wrong rows."));
PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
platform::errors::InvalidArgument(
"Input(RankOffset) has wrong columns."));
PADDLE_ENFORCE_EQ(
max_rank * max_rank * x_fea_dim, para_row,
platform::errors::InvalidArgument("Input(RankParam) has wrong rows."));
int block_matrix_row = max_rank * x_fea_dim;
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto stream = ctx.cuda_device_context().stream();
int device_id = platform::GetCurrentDeviceId();
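    // Scratch buffers (freed at the end of Compute):
    //   param_help : [ins_num * max_rank * x_fea_dim, para_col], per-instance
    //                expanded copy of RankParam;
    //   input_help : [ins_num, max_rank * x_fea_dim], per-instance expanded
    //                copy of X;
    //   ins_rank   : [ins_num], each instance's own rank.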
T *param_help_data;
auto param_help_size = ins_num * block_matrix_row * para_col * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_help_data),
param_help_size, device_id);
platform::GpuMemsetAsync(param_help_data, 0, param_help_size, stream);
T *input_help_data;
auto input_help_size = ins_num * block_matrix_row * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
input_help_size, device_id);
platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
T *ins_rank_data;
auto ins_rank_size = ins_num * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
ins_rank_size, device_id);
platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
Out->mutable_data<T>(ctx.GetPlace());
// initialize
auto out_eigen = framework::EigenVector<T>::Flatten(*Out);
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
// get data ptr
T *out_data = Out->data<T>();
expand_rank_attention_input(
ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
expand_rank_attention_param(
ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
rank_offset->data<int>(), rank_offset_dims[0], rank_offset_dims[1],
param->data<T>(), para_row, para_col, param_help_data,
ins_num * block_matrix_row, para_col, max_rank);
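    // Batched GEMM over the ins_num instances: for each instance i,
    //   Out[i, :] (1 x para_col) =
    //     input_help[i, :] (1 x block_matrix_row) *
    //     param_help_i (block_matrix_row x para_col),
    // where param_help_i is instance i's block of the expanded parameter.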
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
T alpha = 1;
T beta = 0;
int64_t strideA = block_matrix_row;
int64_t strideB = block_matrix_row * para_col;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha,
input_help_data, param_help_data, beta, out_data, ins_num,
strideA, strideB);
platform::RecordedCudaFree(param_help_data, param_help_size, device_id);
platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
}
};
template <typename DeviceContext, typename T>
class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *rank_offset = ctx.Input<Tensor>("RankOffset");
auto *param = ctx.Input<Tensor>("RankParam");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *drank_para = ctx.Output<Tensor>(framework::GradVarName("RankParam"));
// get dim
auto x_dims = X->dims();
auto ins_num = x_dims[0];
auto x_fea_dim = x_dims[1];
auto para_dims = param->dims();
auto para_row = para_dims[0];
auto para_col = para_dims[1];
auto rank_offset_dims = rank_offset->dims();
auto max_rank = (rank_offset_dims[1] - 1) / 2;
int block_matrix_row = max_rank * x_fea_dim;
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
// initialize out grad
drank_para->mutable_data<T>(ctx.GetPlace());
auto drank_para_eigen = framework::EigenVector<T>::Flatten(*drank_para);
drank_para_eigen.device(place) =
drank_para_eigen.constant(static_cast<T>(0));
auto stream = ctx.cuda_device_context().stream();
int device_id = platform::GetCurrentDeviceId();
T *param_grad_data;
auto param_grad_size = ins_num * block_matrix_row * para_col * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_grad_data),
param_grad_size, device_id);
platform::GpuMemsetAsync(param_grad_data, 0, param_grad_size, stream);
T *input_help_data;
auto input_help_size = ins_num * block_matrix_row * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
input_help_size, device_id);
platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
T *ins_rank_data;
auto ins_rank_size = ins_num * sizeof(T);
platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
ins_rank_size, device_id);
platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
// expand input
expand_rank_attention_input(
ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
T alpha = 1;
T beta = 0;
// get param_grad
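    // Batched GEMM over the ins_num instances: for each instance i,
    //   param_grad_i (block_matrix_row x para_col) =
    //     input_help[i, :]^T (block_matrix_row x 1) * dout[i, :] (1 x para_col);
    // the per-instance blocks are then reduced into RankParam@GRAD below.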
CBLAS_TRANSPOSE transA = CblasTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
int64_t strideA = block_matrix_row;
int64_t strideB = para_col;
blas.BatchedGEMM(transA, transB, block_matrix_row, para_col, 1, alpha,
input_help_data, dout->data<T>(), beta, param_grad_data,
ins_num, strideA, strideB);
// merge param_grad to get drank_para
merge_rank_attention_param_grad(
ctx.cuda_device_context().stream(), param_grad_data,
ins_num * block_matrix_row, para_col, drank_para->data<T>(), para_row,
para_col, ins_rank_data, ins_num, max_rank, x_fea_dim);
platform::RecordedCudaFree(param_grad_data, param_grad_size, device_id);
platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(rank_attention,
ops::RankAttentionCUDAKernel<GPUCtx, float>,
ops::RankAttentionCUDAKernel<GPUCtx, double>);
REGISTER_OP_CUDA_KERNEL(rank_attention_grad,
ops::RankAttentionGradOpCUDAKernel<GPUCtx, float>,
ops::RankAttentionGradOpCUDAKernel<GPUCtx, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RankAttentionKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::Unimplemented(
"Rank Attention only supports GPU now."));
}
};
} // namespace operators
} // namespace paddle
@@ -33,7 +33,7 @@ __all__ = [
'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'partial_sum', 'tdm_child'
'partial_sum', 'tdm_child', 'rank_attention'
]
@@ -1017,3 +1017,65 @@ def tdm_child(x, node_nums, child_nums, param_attr=None, dtype='int32'):
'dtype': c_dtype},
stop_gradient=True)
return (child, leaf_mask)
def rank_attention(input,
rank_offset,
rank_param_shape,
rank_param_attr,
max_rank=3):
"""
**Rank Attention layer**
This Op computes rank attention between input and rank_param, where
rank_param gives the organization of the data. Note: it currently supports
the GPU device only.
This Op exists in contrib, which means it is not exposed to the public API.
Args:
input: Tensor with data type float32 or float64.
rank_offset: Tensor with data type int32.
rank_param_shape: The shape of rank_param.
rank_param_attr: Attribute initializer of rank_param.
max_rank: The maximum rank among the input instances.
Returns:
Variable: A Tensor with the same data type as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[None, 2], dtype="float32")
rank_offset = fluid.data(name="rank_offset", shape=[None, 7], dtype="int32")
out = fluid.contrib.layers.rank_attention(input=input,
rank_offset=rank_offset,
rank_param_shape=[18,3],
rank_param_attr=
fluid.ParamAttr(learning_rate=1.0,
name="ubm_rank_param.w_0",
initializer=
fluid.initializer.Xavier(uniform=False)),
max_rank=3)
"""
helper = LayerHelper('rank_attention', **locals())
dtype = helper.input_dtype(input_param_name='input')
input_shape = input.shape
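    # RankParam is organized as max_rank * max_rank blocks of input_shape[1]
    # rows, one block per (rank, faster_rank) pair (see rank_attention.cu.h),
    # hence the shape constraint below.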
assert input_shape[1] * max_rank * max_rank == rank_param_shape[0]
rank_param = helper.create_parameter(
attr=rank_param_attr, shape=rank_param_shape, dtype=dtype)
rank_param.stop_gradient = False
output = helper.create_variable_for_type_inference(dtype)
ins_rank = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
helper.append_op(
type="rank_attention",
inputs={
"X": input,
"RankOffset": rank_offset,
"RankParam": rank_param
},
outputs={"Out": output},
attrs={"MaxRank": max_rank})
return output
@@ -52,6 +52,7 @@ endif()
if (NOT ${WITH_GPU})
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): support CPU device for rank_attention_op in the future
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
elseif(${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
@@ -3030,6 +3030,22 @@ class TestBook(LayerTest):
[x, y], start_index=0, length=2)
return (sum)
def test_rank_attention(self):
with self.static_graph():
input = fluid.data(name="input", shape=[None, 2], dtype="float32")
rank_offset = fluid.data(
name="rank_offset", shape=[None, 7], dtype="int32")
out = fluid.contrib.layers.rank_attention(
input=input,
rank_offset=rank_offset,
rank_param_shape=[18, 3],
rank_param_attr=fluid.ParamAttr(
learning_rate=1.0,
name="ubm_rank_param.w_0",
initializer=fluid.initializer.Xavier(uniform=False)),
max_rank=3)
return (out)
def test_roi_pool(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import random
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid.core as core
def gen_input_help(input, rank_offset, max_rank):
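    # Numpy reference of expand_rank_attention_input: builds an
    # [ins_num, max_rank * input_col] matrix whose k-th block of columns holds
    # the features of the instance occupying rank slot k + 1 in the same pv
    # (zeros when the slot is empty), and records each instance's own rank.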
input_row, input_col = input.shape
input_help = np.zeros((input_row * max_rank * input_col, ))
ins_rank = np.zeros((input_row, 1))
ins_rank.fill(-1)
output_col = max_rank * input_col
output_row = input_row
for idx in range(output_col * output_row):
output_col_idx = idx % output_col
output_row_idx = int(idx / output_col)
k = int(output_col_idx / input_col)
faster = rank_offset[output_row_idx, 2 * k + 1] - 1
if output_col_idx == 0:
ins_rank[output_row_idx] = rank_offset[output_row_idx, 0]
if rank_offset[output_row_idx, 0] - 1 < 0 or faster < 0:
continue
rank_input_col_idx = output_col_idx % input_col
index = rank_offset[output_row_idx, 2 * k + 2]
input_help[idx] = input[index, rank_input_col_idx]
input_help = input_help.reshape([input_row, max_rank * input_col])
return input_help, ins_rank
def gen_param_help(input, rank_offset, param, max_rank):
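    # Numpy reference of expand_rank_attention_param: for every instance,
    # selects the RankParam block addressed by (own_rank - 1) * max_rank +
    # (faster_rank - 1), producing an [ins_num * max_rank * input_col,
    # param_col] matrix that pairs with the expanded input.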
input_row, input_col = input.shape
rank_offset_row, rank_offset_col = rank_offset.shape
param_row, param_col = param.shape
block_matrix_row = input_col * max_rank
output_param_row = block_matrix_row * input_row
output_param_col = param_col
output_param = np.zeros((output_param_row * output_param_col, ))
for idx in range(output_param_row * output_param_col):
output_col_idx = idx % output_param_col
output_row_idx = int(idx / output_param_col)
ins_idx = int(output_row_idx / block_matrix_row)
start_offset = output_row_idx % block_matrix_row
k = int(start_offset / input_col)
k_offset = start_offset % input_col
lower = rank_offset[ins_idx, 0] - 1
faster = rank_offset[ins_idx, 2 * k + 1] - 1
if lower < 0 or faster < 0:
continue
start = lower * max_rank + faster
ori_idx = start * param_col * input_col + k_offset * param_col + output_col_idx
output_param[idx] = param[int(ori_idx / param_col), ori_idx % param_col]
output_param = output_param.reshape([output_param_row, output_param_col])
return output_param
def np_rank_attention(input, rank_offset, rank_para, max_rank):
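    # Numpy reference of the forward pass: out[i, :] is the product of
    # instance i's expanded input row (1 x max_rank*input_col) with its
    # expanded parameter block (max_rank*input_col x rank_para_col).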
input_row, input_col = input.shape
rank_offset_row, rank_offset_col = rank_offset.shape
rank_para_row, rank_para_col = rank_para.shape
assert (input_row == rank_offset_row)
assert (max_rank == ((rank_offset_col - 1) / 2))
assert (rank_para_row == max_rank * max_rank * input_col)
input_help, ins_rank = gen_input_help(input, rank_offset, max_rank)
param_help = gen_param_help(input, rank_offset, rank_para, max_rank)
block_matrix_row = input_col * max_rank
res = np.zeros((input_row, rank_para_col))
for ins in range(input_row):
res[ins, :] = \
np.dot(input_help[ins, :],
param_help[int(block_matrix_row * ins):int(block_matrix_row * (ins+1)),:])
return res, input_help, param_help, ins_rank
def gen_rank_offset(pv_nums, max_rank):
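    # Builds a random RankOffset matrix of shape [all_ins_num, 2*max_rank + 1].
    # For example, one pv whose two ads have ranks [2, 1] (max_rank = 3) yields
    #   row 0 (rank 2): [2, 1, 1, 2, 0, -1, -1]
    #   row 1 (rank 1): [1, 1, 1, 2, 0, -1, -1]
    # i.e. col 0 is the instance's own rank, and cols (2k+1, 2k+2) give the
    # rank and row index of the instance in slot k of the same pv.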
all_ins_num = 0
pv_rank_msg = []
for _ in range(pv_nums):
ins_pv = np.random.randint(1, max_rank + 2) # 1~4
rank_list = list(range(1, ins_pv + 1))
random.shuffle(rank_list)
all_ins_num = all_ins_num + ins_pv
pv_rank_msg.append(rank_list)
rank_offset = np.zeros((all_ins_num, max_rank * 2 + 1)).astype("int32")
rank_offset.fill(-1)
index = 0
for pv_number in range(len(pv_rank_msg)):
pv_ins = pv_rank_msg[pv_number]
ad_num = len(pv_ins)
index_start = index
for j in range(ad_num):
rank = -1
if pv_ins[j] <= max_rank:
rank = pv_ins[j]
rank_offset[index, 0] = rank
if rank > 0:
for k in range(ad_num):
fast_rank = -1
if pv_ins[k] <= max_rank:
fast_rank = pv_ins[k]
if fast_rank > 0:
m = fast_rank - 1
rank_offset[index, 2 * m + 1] = pv_ins[k]
rank_offset[index, 2 * m + 2] = index_start + k
index = index + 1
return all_ins_num, rank_offset
class TestRankAttentionOpComplex(OpTest):
def config(self):
self.pv_num = 100
self.x_feat = 10
self.y_feat = 15
self.max_rank = 3
self.dtype = "float64"
def setUp(self):
self.op_type = "rank_attention"
self.config()
ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
rank_para_shape = [
self.max_rank * self.max_rank * self.x_feat, self.y_feat
]
rank_para = np.random.random(rank_para_shape).astype(self.dtype)
np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
input, np.array(rank_offset), rank_para, self.max_rank)
self.inputs = {
"X": input,
"RankOffset": np.array(rank_offset).astype("int32"),
"RankParam": rank_para
}
self.attrs = {'MaxRank': self.max_rank}
self.outputs = {"Out": np_out}
def test_check_output_gpu(self):
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_gpu(self):
if core.is_compiled_with_cuda():
self.check_grad_with_place(core.CUDAPlace(0), ["RankParam"], "Out")
class TestRankAttentionOpCpu(OpTest):
def config(self):
self.pv_num = 100
self.x_feat = 10
self.y_feat = 15
self.max_rank = 3
self.dtype = "float64"
def setUp(self):
self.op_type = "rank_attention"
self.config()
ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
rank_para_shape = [
self.max_rank * self.max_rank * self.x_feat, self.y_feat
]
rank_para = np.random.random(rank_para_shape).astype(self.dtype)
np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
input, np.array(rank_offset), rank_para, self.max_rank)
self.inputs = {
"X": input,
"RankOffset": np.array(rank_offset).astype("int32"),
"RankParam": rank_para
}
self.attrs = {'MaxRank': self.max_rank}
self.outputs = {"Out": np_out}
def test_check_output_cpu(self):
try:
self.check_output_with_place(place=core.CPUPlace())
except:
print("do not support cpu test, skip")
def test_check_grad_cpu(self):
try:
self.check_grad_with_place(core.CPUPlace(), ["RankParam"], "Out")
except:
print("do not support cpu test, skip")
if __name__ == "__main__":
unittest.main()