Unverified commit 0fb9b208, authored by ShenLiang, committed by GitHub

Add batch_fc op in contrib (#24017)

* add batch fc op, test=develop

* add batch_fc_op, test=develop

* fix unittest, test=develop

* rm check_dygraph, test=develop

* fix comment, test=develop

* fix comment, test=develop
Parent f5c08c3f
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_fc_op.h"
#include <string>
namespace paddle {
namespace operators {
class BatchFCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::InvalidArgument(
            "Input(Input) of BatchFCOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of BatchFCOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"), true,
        platform::errors::InvalidArgument(
            "Input(W) of BatchFCOp should not be null."));
auto input_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
    PADDLE_ENFORCE_EQ(input_dims.size(), 3,
                      platform::errors::InvalidArgument(
                          "Input of BatchFCOp should be a 3-D tensor."));
    PADDLE_ENFORCE_EQ(w_dims.size(), 3,
                      platform::errors::InvalidArgument(
                          "W of BatchFCOp should be a 3-D tensor."));
    PADDLE_ENFORCE_EQ(
        input_dims[0], w_dims[0],
        platform::errors::InvalidArgument(
            "Input.dim[0] and W.dim[0] of BatchFCOp should be the same."));
    PADDLE_ENFORCE_EQ(
        input_dims[2], w_dims[1],
        platform::errors::InvalidArgument(
            "Input.dim[2] and W.dim[1] of BatchFCOp should be the same."));
    auto bias_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0],
                      platform::errors::InvalidArgument(
                          "Bias.dim[0] should be the same as Input.dim[0]."));
    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[2],
                      platform::errors::InvalidArgument(
                          "Bias.dim[1] should be the same as W.dim[2]."));
ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], w_dims[2]});
ctx->ShareLoD("Input", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class BatchFCGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Input"), true,
        platform::errors::InvalidArgument("Input(Input) should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("W"), true,
        platform::errors::InvalidArgument("Input(W) should not be null."));
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor) Input tensor of batch_fc_op operator.");
AddInput("W", "(Tensor) Input tensor of batch_fc_op operator.");
AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator.");
AddOutput("Out", "Output tensor of batch_fc_op operator.");
AddComment(R"DOC(
BatchFC Operator.
Notice: It currently only supports the GPU device.
This Op exists in contrib, which means that it is not exposed in the public API.
)DOC");
}
};
template <typename T>
class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("batch_fc_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("W", this->Input("W"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInference,
"Bias");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
ops::BatchFCGradOpMaker<paddle::framework::OpDesc>,
ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
ops::BatchFCGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
ops::BatchFCKernel<paddle::platform::CPUDeviceContext, double>);
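For orientation, a minimal NumPy sketch of the forward computation registered above (the shapes mirror the checks in BatchFCOp::InferShape; the helper name is illustrative only):

import numpy as np

def batch_fc_forward_ref(inp, w, bias):
    # inp:  [slot_pairs_num, ins_num, in_dim]
    # w:    [slot_pairs_num, in_dim, out_dim]
    # bias: [slot_pairs_num, out_dim]
    # out:  [slot_pairs_num, ins_num, out_dim]
    return np.matmul(inp, w) + bias[:, np.newaxis, :]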
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cublas.h>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/batch_fc_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using framework::Tensor;
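// Grid-stride loop: each thread starts at its global thread index and then
// strides by the total number of launched threads until all n elements have
// been covered, so any grid size can handle any problem size.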
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
static inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename T>
__global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
int out_dim, const T* bias) {
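  // data is laid out as [slot_pairs_num, ins_num, out_dim]; flat index idx
  // belongs to slot idx / (ins_num * out_dim) and output column
  // (idx % (ins_num * out_dim)) % out_dim, so add that slot's bias element.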
CUDA_KERNEL_LOOP(idx, slot_pairs_num * ins_num * out_dim) {
int block_len = ins_num * out_dim;
int slot_index = idx / block_len;
int out_dim_index = (idx % block_len) % out_dim;
T temp = data[idx] + bias[slot_index * out_dim + out_dim_index];
data[idx] = temp;
}
}
template <typename T>
void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
int out_dim, const T* bias) {
add_bias_kernel<<<GET_BLOCKS(slot_pairs_num * ins_num * out_dim),
CUDA_NUM_THREADS, 0, stream>>>(data, slot_pairs_num,
ins_num, out_dim, bias);
}
template <typename T>
__global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
int ins_num, int out_dim, T* db_data) {
  // dout has shape [slot_pairs_num, ins_num, out_dim];
  // db[row, col] accumulates dout[row, i, col] over all ins_num instances i.
  CUDA_KERNEL_LOOP(idx, slot_pairs_num * out_dim) {
    int row = idx / out_dim;
    int col = idx % out_dim;
    T temp = static_cast<T>(0);
    for (int i = 0; i < ins_num; ++i) {
      int select_indx = row * ins_num * out_dim + i * out_dim + col;
      temp += dout_data[select_indx];
    }
    db_data[idx] += temp;
  }
}
template <typename T>
void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
int ins_num, int out_dim, T* db_data) {
add_bias_grad_kernel<<<GET_BLOCKS(slot_pairs_num * out_dim), CUDA_NUM_THREADS,
0, stream>>>(dout_data, slot_pairs_num, ins_num,
out_dim, db_data);
}
template <typename DeviceContext, typename T>
class BatchFCCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// X.dim = slot_pairs_num * ins_num * in_dim
// W.dim = slot_pairs_num * in_dim * out_dim
// b.dim = slot_pairs_num * out_dim
// output.dim = slot_pairs_num * ins_num * out_dim
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* w = ctx.Input<Tensor>("W");
auto* bias = ctx.Input<Tensor>("Bias");
auto* output = ctx.Output<framework::LoDTensor>("Out");
auto input_dims = input->dims();
auto w_dims = w->dims();
auto slot_pairs_num = input_dims[0];
auto ins_num = input_dims[1];
auto in_dim = input_dims[2];
auto out_dim = w_dims[2];
// get data ptr
const T* in_data = input->data<T>();
const T* w_data = w->data<T>();
const T* bias_data = bias->data<T>();
output->Resize({slot_pairs_num, ins_num, out_dim});
T* out_data = output->mutable_data<T>(ctx.GetPlace());
// initialize
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
T alpha = 1;
T beta = 0;
int64_t strideA = ins_num * in_dim;
int64_t strideB = in_dim * out_dim;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
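    // One GEMM per slot: Out[s] (ins_num x out_dim) =
    //   Input[s] (ins_num x in_dim) * W[s] (in_dim x out_dim),
    // batched over slot_pairs_num slots using strides strideA / strideB.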
blas.BatchedGEMM(transA, transB, ins_num, out_dim, in_dim, alpha, in_data,
w_data, beta, out_data, slot_pairs_num, strideA, strideB);
add_bias<T>(ctx.cuda_device_context().stream(), out_data, slot_pairs_num,
ins_num, out_dim, bias_data);
}
};
template <typename DeviceContext, typename T>
class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* w = ctx.Input<Tensor>("W");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto* dw = ctx.Output<Tensor>(framework::GradVarName("W"));
auto* db = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto input_dims = input->dims();
auto w_dims = w->dims();
auto slot_pairs_num = input_dims[0];
auto ins_num = input_dims[1];
auto in_dim = input_dims[2];
auto out_dim = w_dims[2];
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
// initialize
dx->mutable_data<T>(ctx.GetPlace());
auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
dx_eigen.device(place) = dx_eigen.constant(static_cast<T>(0));
dw->mutable_data<T>(ctx.GetPlace());
auto dw_eigen = framework::EigenVector<T>::Flatten(*dw);
dw_eigen.device(place) = dw_eigen.constant(static_cast<T>(0));
// get data ptr
const T* x_data = input->data<T>();
const T* w_data = w->data<T>();
const T* dout_data = dout->data<T>();
T* dx_data = dx->data<T>();
T* dw_data = dw->data<T>();
db->mutable_data<T>(ctx.GetPlace());
auto db_eigen = framework::EigenVector<T>::Flatten(*db);
db_eigen.device(place) = db_eigen.constant(static_cast<T>(0));
T* db_data = db->data<T>();
add_bias_grad<T>(ctx.cuda_device_context().stream(), dout_data,
slot_pairs_num, ins_num, out_dim, db_data);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
T alpha = 1;
T beta = 0;
    // dx = dout * w^T
blas.BatchedGEMM(CblasNoTrans, CblasTrans, ins_num, in_dim, out_dim, alpha,
dout_data, w_data, beta, dx_data, slot_pairs_num,
ins_num * out_dim, out_dim * in_dim);
    // dw = x^T * dout
blas.BatchedGEMM(CblasTrans, CblasNoTrans, in_dim, out_dim, ins_num, alpha,
x_data, dout_data, beta, dw_data, slot_pairs_num,
in_dim * ins_num, ins_num * out_dim);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(batch_fc, ops::BatchFCCUDAKernel<GPUCtx, float>,
ops::BatchFCCUDAKernel<GPUCtx, double>);
REGISTER_OP_CUDA_KERNEL(batch_fc_grad,
ops::BatchFCGradOpCUDAKernel<GPUCtx, float>,
ops::BatchFCGradOpCUDAKernel<GPUCtx, double>);
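For reference, the backward pass implemented above by the two BatchedGEMM calls and add_bias_grad corresponds, per slot, to the following NumPy sketch (the function name is illustrative; shapes are as in the forward kernel comments):

import numpy as np

def batch_fc_grad_ref(x, w, dout):
    # x:    [slot_pairs_num, ins_num, in_dim]
    # w:    [slot_pairs_num, in_dim, out_dim]
    # dout: [slot_pairs_num, ins_num, out_dim]
    dx = np.matmul(dout, np.transpose(w, (0, 2, 1)))  # dX[s] = dOut[s] @ W[s]^T
    dw = np.matmul(np.transpose(x, (0, 2, 1)), dout)  # dW[s] = X[s]^T @ dOut[s]
    db = dout.sum(axis=1)                             # dB[s, j] = sum_i dOut[s, i, j]
    return dx, dw, db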
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class BatchFCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::Unimplemented("BatchFC only supports GPU now."));
}
};
} // namespace operators
} // namespace paddle
@@ -34,7 +34,7 @@ __all__ = [
'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler'
'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc'
]
@@ -1298,3 +1298,66 @@ def rank_attention(input,
attrs={"MaxRank": max_rank,
"MaxSize": max_size})
return output
def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None):
"""
**Batch FC layer**
    This Op computes a batched fully connected transform: for each slot s,
    Out[s] = Input[s] * W[s] + Bias[s], optionally followed by an activation.
    It is similar to the matmul op, except that a per-slot bias and an optional
    activation are added.
    Notice: It currently only supports the GPU device.
    This Op exists in contrib, which means that it is not exposed in the public API.
    Args:
        input: A 3-D Tensor of shape [slot_pairs_num, ins_num, in_dim] with
            data type float32 or float64.
        param_size: The shape of w, i.e. [slot_pairs_num, in_dim, out_dim].
        param_attr: The parameter attribute (ParamAttr) used to create w.
        bias_size: The shape of bias, i.e. [slot_pairs_num, out_dim].
        bias_attr: The parameter attribute (ParamAttr) used to create bias.
        act: Activation to be applied to the output of this layer.
    Returns:
        Variable: A 3-D Tensor of shape [slot_pairs_num, ins_num, out_dim] with
            the same data type as input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32")
out = fluid.contrib.layers.batch_fc(input=input,
param_size=[16, 3, 10],
param_attr=
fluid.ParamAttr(learning_rate=1.0,
name="w_0",
initializer=
fluid.initializer.Xavier(uniform=False)),
bias_size=[16, 10],
bias_attr=
fluid.ParamAttr(learning_rate=1.0,
name="b_0",
initializer=
fluid.initializer.Xavier(uniform=False)),
act="relu")
"""
helper = LayerHelper("batch_fc", **locals())
check_type(input, 'input', (Variable), 'batch_fc')
input_shape = input.shape
    assert input_shape[0] == param_size[0], \
        "input.shape[0] (slot_pairs_num) must equal param_size[0]."
    assert input_shape[2] == param_size[1], \
        "input.shape[2] (in_dim) must equal param_size[1]."
    assert param_size[2] == bias_size[1], \
        "param_size[2] (out_dim) must equal bias_size[1]."
    assert input_shape[0] == bias_size[0], \
        "input.shape[0] (slot_pairs_num) must equal bias_size[0]."
dtype = helper.input_dtype()
check_dtype(dtype, 'input', ['float32', 'float64'], 'batch_fc')
w = helper.create_parameter(
attr=param_attr, shape=param_size, dtype=dtype, is_bias=False)
b = helper.create_parameter(
attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=False)
pre_act = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="batch_fc",
inputs={"Input": input,
"W": w,
"Bias": b},
outputs={"Out": pre_act})
return helper.append_activation(pre_act)
@@ -54,6 +54,7 @@ endif()
if (NOT ${WITH_GPU})
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
LIST(REMOVE_ITEM TEST_OPS test_batch_fc_op) # TODO(shenliang03): batch_fc_op support CPU device in future
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
elseif(${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import random
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid.core as core
def np_cal_batchfc(input, w, bias):
slot_pairs_num, batch_size, in_dim = input.shape
_, _, out_dim = w.shape
res = np.zeros((slot_pairs_num, batch_size, out_dim))
for slot in range(slot_pairs_num):
res[slot, :] = np.dot(input[slot, :], w[slot, :])
for slot in range(slot_pairs_num):
for bindx in range(out_dim):
res[slot, :, bindx] += bias[slot, bindx]
return res
class TestBatchFCOp(OpTest):
def config(self):
self.slot_pairs_num = 10
self.batch_size = 5
self.in_dim = 10
self.out_dim = 12
self.dtype = "float64"
def setUp(self):
self.config()
self.input = np.random.random((self.slot_pairs_num, self.batch_size,
self.in_dim)).astype(self.dtype)
self.w = np.random.random((self.slot_pairs_num, self.in_dim,
self.out_dim)).astype(self.dtype)
self.bias = np.random.random((self.slot_pairs_num,
self.out_dim)).astype(self.dtype)
self.op_type = "batch_fc"
np_out = np_cal_batchfc(self.input, self.w, self.bias)
np_out = np_out.astype(self.dtype)
self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
self.outputs = {"Out": np_out}
def test_check_output_gpu(self):
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_gpu(self):
if core.is_compiled_with_cuda():
self.check_grad_with_place(
core.CUDAPlace(0), ["Bias", "W", "Input"], "Out")
class TestBatchFCOp1(OpTest):
def config(self):
self.slot_pairs_num = 10
self.batch_size = 5
self.in_dim = 10
self.out_dim = 12
self.dtype = "float64"
def setUp(self):
self.config()
self.input = np.random.random((self.slot_pairs_num, self.batch_size,
self.in_dim)).astype(self.dtype)
self.w = np.random.random((self.slot_pairs_num, self.in_dim,
self.out_dim)).astype(self.dtype)
self.bias = np.random.random((self.slot_pairs_num,
self.out_dim)).astype(self.dtype)
self.op_type = "batch_fc"
np_out = np_cal_batchfc(self.input, self.w, self.bias)
np_out = np_out.astype(self.dtype)
self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
self.outputs = {"Out": np_out}
def test_check_output_cpu(self):
try:
self.check_output_with_place(place=core.CPUPlace())
except:
print("do not support cpu test, skip")
def test_check_grad_cpu(self):
try:
self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"],
"Out")
except:
print("do not support cpu test, skip")
if __name__ == "__main__":
unittest.main()
@@ -3195,6 +3195,24 @@ class TestBook(LayerTest):
[x, y], start_index=0, length=2)
return (sum)
def test_batch_fc(self):
with self.static_graph():
input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32")
out = fluid.contrib.layers.batch_fc(
input=input,
param_size=[16, 3, 10],
param_attr=fluid.ParamAttr(
learning_rate=1.0,
name="w_0",
initializer=fluid.initializer.Xavier(uniform=False)),
bias_size=[16, 10],
bias_attr=fluid.ParamAttr(
learning_rate=1.0,
name="b_0",
initializer=fluid.initializer.Xavier(uniform=False)),
act="relu")
return (out)
def test_rank_attention(self):
with self.static_graph():
input = fluid.data(name="input", shape=[None, 2], dtype="float32")