From 0fb9b208ab285cd240abe6df3c788e207368c4cf Mon Sep 17 00:00:00 2001
From: ShenLiang <2282912238@qq.com>
Date: Sun, 26 Apr 2020 21:08:46 +0800
Subject: [PATCH] Add batch_fc op in contrib (#24017)

* add batch fc op, test=develop

* add batch_fc_op, test=develop

* fix unittest, test=develop

* rm check_dygraph, test=develop

* fix comment, test=develop

* fix comment, test=develop
---
 paddle/fluid/operators/batch_fc_op.cc         | 155 ++++++++++++++
 paddle/fluid/operators/batch_fc_op.cu         | 198 ++++++++++++++++++
 paddle/fluid/operators/batch_fc_op.h          |  32 +++
 python/paddle/fluid/contrib/layers/nn.py      |  65 +++++-
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../fluid/tests/unittests/test_batch_fc_op.py | 106 ++++++++++
 .../fluid/tests/unittests/test_layers.py      |  18 ++
 7 files changed, 574 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/batch_fc_op.cc
 create mode 100644 paddle/fluid/operators/batch_fc_op.cu
 create mode 100644 paddle/fluid/operators/batch_fc_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_batch_fc_op.py

diff --git a/paddle/fluid/operators/batch_fc_op.cc b/paddle/fluid/operators/batch_fc_op.cc
new file mode 100644
index 00000000000..004ebe0eb81
--- /dev/null
+++ b/paddle/fluid/operators/batch_fc_op.cc
@@ -0,0 +1,155 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/batch_fc_op.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class BatchFCOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Input"), true,
+        platform::errors::InvalidArgument(
+            "Input(Input) of Batch Fully Connected should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of Batch Fully Connected should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("W"), true,
+        platform::errors::InvalidArgument(
+            "Input(W) of Batch Fully Connected should not be null."));
+
+    auto input_dims = ctx->GetInputDim("Input");
+    auto w_dims = ctx->GetInputDim("W");
+
+    PADDLE_ENFORCE_EQ(input_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "Input of BatchFCOp should be a 3-D tensor."));
+    PADDLE_ENFORCE_EQ(w_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "W of BatchFCOp should be a 3-D tensor."));
+    PADDLE_ENFORCE_EQ(
+        input_dims[0], w_dims[0],
+        platform::errors::InvalidArgument(
+            "Input.dim[0] and W.dim[0] of BatchFCOp should be the same."));
+    PADDLE_ENFORCE_EQ(
+        input_dims[2], w_dims[1],
+        platform::errors::InvalidArgument(
+            "Input.dim[2] and W.dim[1] of BatchFCOp should be the same."));
+
+    auto bias_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0],
+                      platform::errors::InvalidArgument(
+                          "Bias.dim[0] should be the same as Input.dim[0]."));
+    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[2],
+                      platform::errors::InvalidArgument(
+                          "Bias.dim[1] should be the same as W.dim[2]."));
+
+    ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], w_dims[2]});
+    ctx->ShareLoD("Input", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
+  }
+};
+
+class BatchFCGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Input"), true,
+        platform::errors::InvalidArgument("Input(Input) should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("W"), true,
+        platform::errors::InvalidArgument("Input(W) should not be null."));
+
+    ctx->SetOutputDim(framework::GradVarName("Input"),
+                      ctx->GetInputDim("Input"));
+    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
+class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input", "(Tensor) Input tensor of batch_fc_op operator.");
+    AddInput("W", "(Tensor) Weight tensor of batch_fc_op operator.");
+    AddInput("Bias", "(Tensor) Bias tensor of batch_fc_op operator.");
+    AddOutput("Out", "(Tensor) Output tensor of batch_fc_op operator.");
+    AddComment(R"DOC(
+BatchFC Operator.
+Notice: It currently only supports GPU devices.
+This Op exists in contrib, which means it is not exposed in the public API.
+)DOC");
+  }
+};
+
+template <typename T>
+class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("batch_fc_grad");
+
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("W", this->Input("W"));
+    op->SetInput("Bias", this->Input("Bias"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
+    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInference,
+                                    "Bias");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
+                  ops::BatchFCGradOpMaker<paddle::framework::OpDesc>,
+                  ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
+                  ops::BatchFCGradOpNoNeedBufferVarsInference);
+
+REGISTER_OP_CPU_KERNEL(
+    batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchFCKernel<paddle::platform::CPUDeviceContext, double>);
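The InferShape checks above pin down the op's contract: one independent fully
connected layer per slot pair. For reference, a minimal NumPy sketch of the
intended forward semantics (the helper below is illustrative, not part of the
patch):

    import numpy as np

    def batch_fc_forward_reference(inp, w, bias):
        # inp:  [slot_pairs_num, ins_num, in_dim]
        # w:    [slot_pairs_num, in_dim, out_dim]
        # bias: [slot_pairs_num, out_dim]
        assert inp.shape[0] == w.shape[0] and inp.shape[2] == w.shape[1]
        assert bias.shape == (w.shape[0], w.shape[2])
        # Out[s] = inp[s] @ w[s] + bias[s] for every slot s.
        return np.matmul(inp, w) + bias[:, None, :]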
diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu
new file mode 100644
index 00000000000..414eeef2a6f
--- /dev/null
+++ b/paddle/fluid/operators/batch_fc_op.cu
@@ -0,0 +1,198 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/batch_fc_op.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+static inline int GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename T>
+__global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
+                                int out_dim, const T* bias) {
+  CUDA_KERNEL_LOOP(idx, slot_pairs_num * ins_num * out_dim) {
+    int block_len = ins_num * out_dim;
+    int slot_index = idx / block_len;
+    int out_dim_index = (idx % block_len) % out_dim;
+    T temp = data[idx] + bias[slot_index * out_dim + out_dim_index];
+    data[idx] = temp;
+  }
+}
+
+template <typename T>
+void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
+              int out_dim, const T* bias) {
+  add_bias_kernel<<<GET_BLOCKS(slot_pairs_num * ins_num * out_dim),
+                    CUDA_NUM_THREADS, 0, stream>>>(data, slot_pairs_num,
+                                                   ins_num, out_dim, bias);
+}
+
+template <typename T>
+__global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
+                                     int ins_num, int out_dim, T* db_data) {
+  CUDA_KERNEL_LOOP(idx, slot_pairs_num * out_dim) {
+    int row = idx / out_dim;
+    int col = idx % out_dim;
+    T temp = static_cast<T>(0);
+    for (int i = 0; i < ins_num; ++i) {
+      // dout is laid out as [slot_pairs_num, ins_num, out_dim], so the
+      // element for (slot=row, ins=i, out=col) lives at the index below.
+      int select_indx = (row * ins_num + i) * out_dim + col;
+      temp += dout_data[select_indx];
+    }
+    db_data[idx] += temp;
+  }
+}
+
+template <typename T>
+void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
+                   int ins_num, int out_dim, T* db_data) {
+  add_bias_grad_kernel<<<GET_BLOCKS(slot_pairs_num * out_dim),
+                         CUDA_NUM_THREADS, 0, stream>>>(
+      dout_data, slot_pairs_num, ins_num, out_dim, db_data);
+}
+
+template <typename DeviceContext, typename T>
+class BatchFCCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // X.dim = slot_pairs_num * ins_num * in_dim
+    // W.dim = slot_pairs_num * in_dim * out_dim
+    // b.dim = slot_pairs_num * out_dim
+    // output.dim = slot_pairs_num * ins_num * out_dim
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* w = ctx.Input<Tensor>("W");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* output = ctx.Output<Tensor>("Out");
+    auto input_dims = input->dims();
+    auto w_dims = w->dims();
+    auto slot_pairs_num = input_dims[0];
+    auto ins_num = input_dims[1];
+    auto in_dim = input_dims[2];
+    auto out_dim = w_dims[2];
+
+    // get data ptr
+    const T* in_data = input->data<T>();
+    const T* w_data = w->data<T>();
+    const T* bias_data = bias->data<T>();
+
+    output->Resize({slot_pairs_num, ins_num, out_dim});
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    // initialize
+    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+    out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
+
+    CBLAS_TRANSPOSE transA = CblasNoTrans;
+    CBLAS_TRANSPOSE transB = CblasNoTrans;
+
+    T alpha = 1;
+    T beta = 0;
+    int64_t strideA = ins_num * in_dim;
+    int64_t strideB = in_dim * out_dim;
+
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+    blas.BatchedGEMM(transA, transB, ins_num, out_dim, in_dim, alpha, in_data,
+                     w_data, beta, out_data, slot_pairs_num, strideA, strideB);
+    add_bias<T>(ctx.cuda_device_context().stream(), out_data, slot_pairs_num,
+                ins_num, out_dim, bias_data);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class BatchFCGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* w = ctx.Input<Tensor>("W");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto* dw = ctx.Output<Tensor>(framework::GradVarName("W"));
+    auto* db = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    auto input_dims = input->dims();
+    auto w_dims = w->dims();
+    auto slot_pairs_num = input_dims[0];
+    auto ins_num = input_dims[1];
+    auto in_dim = input_dims[2];
+    auto out_dim = w_dims[2];
+
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+    // initialize
+    dx->mutable_data<T>(ctx.GetPlace());
+    auto dx_eigen = framework::EigenVector<T>::Flatten(*dx);
+    dx_eigen.device(place) = dx_eigen.constant(static_cast<T>(0));
+
+    dw->mutable_data<T>(ctx.GetPlace());
+    auto dw_eigen = framework::EigenVector<T>::Flatten(*dw);
+    dw_eigen.device(place) = dw_eigen.constant(static_cast<T>(0));
+
+    // get data ptr
+    const T* x_data = input->data<T>();
+    const T* w_data = w->data<T>();
+    const T* dout_data = dout->data<T>();
+    T* dx_data = dx->data<T>();
+    T* dw_data = dw->data<T>();
+
+    db->mutable_data<T>(ctx.GetPlace());
+    auto db_eigen = framework::EigenVector<T>::Flatten(*db);
+    db_eigen.device(place) = db_eigen.constant(static_cast<T>(0));
+    T* db_data = db->data<T>();
+    add_bias_grad<T>(ctx.cuda_device_context().stream(), dout_data,
+                     slot_pairs_num, ins_num, out_dim, db_data);
+
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+    T alpha = 1;
+    T beta = 0;
+
+    // dX[s] = dOut[s] * W[s]^T
+    blas.BatchedGEMM(CblasNoTrans, CblasTrans, ins_num, in_dim, out_dim, alpha,
+                     dout_data, w_data, beta, dx_data, slot_pairs_num,
+                     ins_num * out_dim, out_dim * in_dim);
+    // dW[s] = X[s]^T * dOut[s]
+    blas.BatchedGEMM(CblasTrans, CblasNoTrans, in_dim, out_dim, ins_num, alpha,
+                     x_data, dout_data, beta, dw_data, slot_pairs_num,
+                     in_dim * ins_num, ins_num * out_dim);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using GPUCtx = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(batch_fc, ops::BatchFCCUDAKernel<GPUCtx, float>,
+                        ops::BatchFCCUDAKernel<GPUCtx, double>);
+
+REGISTER_OP_CUDA_KERNEL(batch_fc_grad,
+                        ops::BatchFCGradOpCUDAKernel<GPUCtx, float>,
+                        ops::BatchFCGradOpCUDAKernel<GPUCtx, double>);
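The two BatchedGEMM calls and add_bias_grad above are the standard fully
connected gradients applied independently per slot. The same math in NumPy,
assuming the layouts documented in the forward kernel (helper names are
illustrative only):

    import numpy as np

    def batch_fc_grad_reference(inp, w, dout):
        # dout: [slot_pairs_num, ins_num, out_dim]
        dx = np.matmul(dout, np.transpose(w, (0, 2, 1)))    # dX[s] = dOut[s] @ W[s]^T
        dw = np.matmul(np.transpose(inp, (0, 2, 1)), dout)  # dW[s] = X[s]^T @ dOut[s]
        db = dout.sum(axis=1)                               # reduce over ins_num
        return dx, dw, db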
diff --git a/paddle/fluid/operators/batch_fc_op.h b/paddle/fluid/operators/batch_fc_op.h
new file mode 100644
index 00000000000..a3b8a6942ba
--- /dev/null
+++ b/paddle/fluid/operators/batch_fc_op.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class BatchFCKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        platform::errors::Unimplemented("BatchFC only supports GPU now."));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 3b6372c000c..3bb56edb9b7 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -34,7 +34,7 @@ __all__ = [
     'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
     'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
     'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
-    'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler'
+    'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc'
 ]
 
 
@@ -1298,3 +1298,66 @@ def rank_attention(input,
         attrs={"MaxRank": max_rank,
                "MaxSize": max_size})
     return output
+
+
+def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None):
+    """
+    **Batch FC layer**
+    This Op computes one fully connected layer per slot: for each slot it
+    multiplies the input by a slot-specific weight, adds a slot-specific
+    bias, and optionally applies an activation.
+    Notice: It currently only supports GPU devices.
+    This Op exists in contrib, which means it is not exposed in the public API.
+    Args:
+        input: A 3-D Tensor with shape [slot_pairs_num, batch_size, in_dim]
+            and data type float32 or float64.
+        param_size: The shape of the weight w, i.e.
+            [slot_pairs_num, in_dim, out_dim].
+        param_attr: The parameter attribute of w.
+        bias_size: The shape of the bias, i.e. [slot_pairs_num, out_dim].
+        bias_attr: The parameter attribute of the bias.
+        act: Activation to be applied to the output of this layer. Default: None.
+
+    Returns:
+        Variable: A 3-D Tensor with the same data type as input.
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32")
+            out = fluid.contrib.layers.batch_fc(
+                input=input,
+                param_size=[16, 3, 10],
+                param_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="w_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                bias_size=[16, 10],
+                bias_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="b_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                act="relu")
+    """
+
+    helper = LayerHelper("batch_fc", **locals())
+    check_type(input, 'input', (Variable), 'batch_fc')
+    input_shape = input.shape
+    assert input_shape[0] == param_size[0]
+    assert input_shape[2] == param_size[1]
+    assert param_size[2] == bias_size[1]
+    assert input_shape[0] == bias_size[0]
+
+    dtype = helper.input_dtype()
+    check_dtype(dtype, 'input', ['float32', 'float64'], 'batch_fc')
+
+    w = helper.create_parameter(
+        attr=param_attr, shape=param_size, dtype=dtype, is_bias=False)
+    b = helper.create_parameter(
+        attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=False)
+    pre_act = helper.create_variable_for_type_inference(dtype)
+    helper.append_op(
+        type="batch_fc",
+        inputs={"Input": input,
+                "W": w,
+                "Bias": b},
+        outputs={"Out": pre_act})
+    return helper.append_activation(pre_act)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 0919cbe188c..1f8e6d31316 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -54,6 +54,7 @@ endif()
 if (NOT ${WITH_GPU})
     LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
     LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
+    LIST(REMOVE_ITEM TEST_OPS test_batch_fc_op) # TODO(shenliang03): batch_fc_op support CPU device in future
    LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
 elseif(${CUDNN_VERSION} VERSION_LESS 7100)
     LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
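The batch_fc layer only builds the graph; executing it needs a CUDA place,
since only a GPU kernel is registered. A minimal run sketch, assuming a CUDA
build of Paddle (parameter attributes are left at their defaults for brevity):

    import numpy as np
    import paddle.fluid as fluid

    input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32")
    out = fluid.contrib.layers.batch_fc(
        input=input, param_size=[16, 3, 10], param_attr=None,
        bias_size=[16, 10], bias_attr=None, act="relu")

    exe = fluid.Executor(fluid.CUDAPlace(0))  # GPU-only op
    exe.run(fluid.default_startup_program())
    x = np.random.random((16, 2, 3)).astype("float32")
    res, = exe.run(feed={"input": x}, fetch_list=[out])
    print(res.shape)  # (16, 2, 10)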
diff --git a/python/paddle/fluid/tests/unittests/test_batch_fc_op.py b/python/paddle/fluid/tests/unittests/test_batch_fc_op.py
new file mode 100644
index 00000000000..56631d8d3b4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_batch_fc_op.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+
+
+def np_cal_batchfc(input, w, bias):
+    slot_pairs_num, batch_size, in_dim = input.shape
+    _, _, out_dim = w.shape
+    res = np.zeros((slot_pairs_num, batch_size, out_dim))
+    for slot in range(slot_pairs_num):
+        res[slot, :] = np.dot(input[slot, :], w[slot, :])
+    for slot in range(slot_pairs_num):
+        for bindx in range(out_dim):
+            res[slot, :, bindx] += bias[slot, bindx]
+    return res
+
+
+class TestBatchFCOp(OpTest):
+    def config(self):
+        self.slot_pairs_num = 10
+        self.batch_size = 5
+        self.in_dim = 10
+        self.out_dim = 12
+        self.dtype = "float64"
+
+    def setUp(self):
+        self.config()
+        self.input = np.random.random((self.slot_pairs_num, self.batch_size,
+                                       self.in_dim)).astype(self.dtype)
+        self.w = np.random.random((self.slot_pairs_num, self.in_dim,
+                                   self.out_dim)).astype(self.dtype)
+        self.bias = np.random.random((self.slot_pairs_num,
+                                      self.out_dim)).astype(self.dtype)
+        self.op_type = "batch_fc"
+        np_out = np_cal_batchfc(self.input, self.w, self.bias)
+        np_out = np_out.astype(self.dtype)
+        self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
+        self.outputs = {"Out": np_out}
+
+    def test_check_output_gpu(self):
+        if core.is_compiled_with_cuda():
+            self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad_gpu(self):
+        if core.is_compiled_with_cuda():
+            self.check_grad_with_place(
+                core.CUDAPlace(0), ["Bias", "W", "Input"], "Out")
+
+
+class TestBatchFCOp1(TestBatchFCOp):
+    # Reuses the setup of TestBatchFCOp; the CPU checks below are expected
+    # to fail until a CPU kernel is implemented, so failures are swallowed.
+    def test_check_output_cpu(self):
+        try:
+            self.check_output_with_place(place=core.CPUPlace())
+        except Exception:
+            print("do not support cpu test, skip")
+
+    def test_check_grad_cpu(self):
+        try:
+            self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"],
+                                       "Out")
+        except Exception:
+            print("do not support cpu test, skip")
+
+
+if __name__ == "__main__":
+    unittest.main()
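check_grad_with_place compares the analytic gradients produced by
batch_fc_grad against numeric estimates. Conceptually, the numeric side
resembles the central-difference sketch below (a simplification of what
OpTest actually does; the helper name is hypothetical):

    import numpy as np

    def numeric_grad(f, x, eps=1e-5):
        # Estimates d(sum(f())) / dx element-wise by central differences,
        # perturbing and restoring x in place.
        g = np.zeros_like(x)
        it = np.nditer(x, flags=["multi_index"])
        while not it.finished:
            i = it.multi_index
            orig = x[i]
            x[i] = orig + eps
            fp = f().sum()
            x[i] = orig - eps
            fm = f().sum()
            x[i] = orig
            g[i] = (fp - fm) / (2.0 * eps)
            it.iternext()
        return g

    # e.g. numeric_grad(lambda: np_cal_batchfc(inp, w, bias), bias)
    # approximates the Bias gradient that add_bias_grad computes analytically.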
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 1df1f34e761..db1dba071fa 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -3195,6 +3195,24 @@ class TestBook(LayerTest):
                 [x, y], start_index=0, length=2)
             return (sum)
 
+    def test_batch_fc(self):
+        with self.static_graph():
+            input = fluid.data(name="input", shape=[16, 2, 3], dtype="float32")
+            out = fluid.contrib.layers.batch_fc(
+                input=input,
+                param_size=[16, 3, 10],
+                param_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="w_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                bias_size=[16, 10],
+                bias_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="b_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                act="relu")
+            return (out)
+
     def test_rank_attention(self):
         with self.static_graph():
             input = fluid.data(name="input", shape=[None, 2], dtype="float32")
-- 
GitLab