Unverified commit a65c728e, authored by Yiqun Liu, committed by GitHub

Implement the GPU kernel of fc operator (#19687)

* Refine the codes related to fc op.

* Add GPU implementation for fc functor.

* Apply fc_fuse_pass in GPU inference.
test=develop

* Change the cmake for fc op.

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.

* Add an attribute to set the activation type in fc_op.

* Enhance the unittest of fc_op.
test=develop

* Remove the declaration of FCOpGrad back to the header file.
test=develop

* Set default value for newly added arguments in test_fc_op.
test=develop
Parent 22301115
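The central refactor in this diff replaces the header-only `math::FCCompute` helper (CPU-only, bound to a `BlasT` handle) with a `math::FCFunctor<DeviceContext, T>` class, so the same call site dispatches to either the CPU implementation (BLAS MatMul plus JIT bias-add) or the new CUDA one (GEMM plus a bias/ReLU kernel). A minimal sketch of the call-site change, mirroring the pattern used in the fused-op kernels below; `CallFC` and the pointer/size names are placeholders, not code from this commit:

```cpp
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/fc.h"

// Hedged sketch: how a kernel body calls the new functor. x_data, w_data,
// y_data, bias_data, M, N, K stand in for the kernel's own locals.
template <typename DeviceContext, typename T>
void CallFC(const paddle::framework::ExecutionContext& ctx, int M, int N,
            int K, const T* x_data, const T* w_data, T* y_data,
            const T* bias_data) {
  // Before this commit: math::FCCompute<DeviceContext, T>(blas, M, N, K, ...),
  // bound to a BlasT handle and implemented for CPU only.
  auto& dev_ctx = ctx.template device_context<DeviceContext>();
  paddle::operators::math::FCFunctor<DeviceContext, T> fc;
  fc(dev_ctx, M, N, K, x_data, w_data, y_data,
     bias_data /* may be nullptr */, /*relu=*/false);
}
```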
......@@ -191,9 +191,6 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......
......@@ -21,7 +21,6 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......
......@@ -103,9 +103,10 @@ const std::vector<std::string> kAnakinSubgraphPasses({
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"fc_fuse_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", //
......
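Enabling `fc_fuse_pass` in `GpuPassStrategy` means GPU inference graphs now get the rewrite the CPU path already used: a `mul` op followed by `elementwise_add` is folded into a single `fc` op, which the new CUDA kernel in this commit can execute directly, with the bias add (and optionally a ReLU via the new `activation_type` attribute) fused into the same launch. As a scalar sanity check of the equivalence the pass relies on, a small hedged sketch, not the pass implementation:

```cpp
#include <algorithm>

// Unfused subgraph: out = relu(mul(x, w) + b), computed one op at a time.
// The fused fc op evaluates the same expression in a single kernel.
float fc_reference(float x, float w, float b, bool relu) {
  float out = x * w + b;                    // mul + elementwise_add
  return relu ? std::max(out, 0.0f) : out;  // optional fused activation
}
```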
......@@ -90,7 +90,7 @@ endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -339,10 +339,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
atted_x_data, atten_b_data);
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data,
atten_b_data);
const T* cur_atten_x_data = atted_x_data;
const T* cur_x_data = x_data;
......@@ -369,8 +372,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
// 1d. softmax
vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
// mul x(seq_len*M) and sum pool
math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
cur_x_data, lstm_x_data);
fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data);
/// 2. compute LSTM step
// lstm weight : concat[forget , input , output , tilde]
......
......@@ -14,65 +14,76 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
void FCOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"X(Input) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of Fully Connected should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
class FCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"X(Input) of Fully Connected should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
"W(Input) of Fully Connected should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1,
"The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
"The shape of Bias must be [1, dim].");
}
}
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
"The shape of Bias must be [1, dim].");
auto& activation_type = ctx->Attrs().Get<std::string>("activation_type");
if (!activation_type.empty()) {
PADDLE_ENFORCE_EQ(activation_type, "relu",
"Activation %s is not supportetd in fc now.",
activation_type.c_str());
}
}
if (ctx->Attrs().Get<bool>("use_mkldnn")) {
PADDLE_ENFORCE_EQ(in_dims.size() == 2 || in_dims.size() == 4, true,
"Fully Connected input should be 2-D or 4-D tensor.");
}
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"Fully Connected input should be 2-D tensor.");
int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
PADDLE_ENFORCE_GT(
in_dims.size(), in_num_col_dims,
"The input tensor Input's rank of FCOp should be larger than "
"in_num_col_dims.");
std::vector<int64_t> output_dims;
FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
if (ctx->Attrs().Get<bool>("use_mkldnn")) {
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
}
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"Fully Connected input should be 2-D tensor.");
int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
PADDLE_ENFORCE_GT(
in_dims.size(), in_num_col_dims,
"The input tensor Input's rank of FCOp should be larger than "
"in_num_col_dims.");
std::vector<int64_t> output_dims;
FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
}
framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
};
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
......@@ -86,8 +97,8 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Should have bias grad");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
"Should have bias grad");
auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
}
......@@ -105,61 +116,36 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
ctx.GetPlace(), layout, library);
}
void FCOpMaker::Make() {
AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddAttr<int>("in_num_col_dims",
"(int, default 1), The fc op can take tensors with more than "
"two dimensions as its inputs.")
.SetDefault(1)
.EqualGreaterThan(1);
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
AddComment(R"DOC(
Fully Connected Operator.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters checked in the infer-shape.
)DOC");
}
template <typename T>
class FCOpKernel : public framework::OpKernel<T> {
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto input = ctx.Input<framework::LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
auto w_dims = w->dims();
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims();
int M = framework::product(out_dims) / w_dims[1];
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
math::FCCompute<platform::CPUDeviceContext, T>(
blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL);
// TODO(TJ): fuse act
void Make() override {
AddInput("Input",
"(Tensor), The input tensor of fully connected operator.");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddOutput("Out",
"(Tensor) The output tensor of fully connected operator. ");
AddAttr<int>("in_num_col_dims",
"(int, default 1), The fc op can take tensors with more than "
"two dimensions as its inputs.")
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<std::string>("activation_type",
"Avctivation type used in fully connected operator.")
.SetDefault("");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
AddComment(R"DOC(
Fully Connected Operator.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters is checked during shape inference.
)DOC");
}
};
......@@ -170,4 +156,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
REGISTER_OP_CPU_KERNEL(
fc, ops::FCOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FCOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fc, ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -14,24 +14,16 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FCOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -43,11 +35,6 @@ class FCOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override;
};
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
......@@ -64,5 +51,38 @@ inline void FCOutputSize(const framework::DDim& in_dims,
out_dims.push_back(w_dims[1]);
}
template <typename DeviceContext, typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* w = ctx.Input<Tensor>("W");
auto* bias = ctx.Input<Tensor>("Bias");
auto* output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
bool with_relu = ctx.Attr<std::string>("activation_type") == "relu";
auto w_dims = w->dims();
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims();
int M = framework::product(out_dims) / w_dims[1];
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu);
}
};
} // namespace operators
} // namespace paddle
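In `FCOpKernel::Compute` above, `M` is the row count of the GEMM after the leading `in_num_col_dims` dimensions of the input are flattened: `FCOutputSize` builds `out_dims` from those leading dimensions plus `w_dims[1]`, so `M = product(out_dims) / w_dims[1]`, while `K = w_dims[0]` is the flattened remainder of the input. A worked example with hypothetical shapes:

```cpp
#include <cassert>

int main() {
  // Hypothetical case: Input shape [2, 8, 3, 3] with in_num_col_dims = 1,
  // so the flattened input is [2, 72] and W must be [72, N]; take N = 15.
  const int N = 15;                     // w_dims[1]
  const int out_dims_product = 2 * 15;  // FCOutputSize -> Out = [2, 15]
  const int M = out_dims_product / N;   // rows fed to the GEMM
  assert(M == 2);
  return 0;
}
```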
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
......@@ -219,9 +219,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const T* wh_state_data = wh_data + D * D2;
T* hidden_out_data = hidden_out->mutable_data<T>(place);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data,
bias ? bias->data<T>() : nullptr);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr);
int xx_offset = D3;
int gate_offset = D;
......@@ -290,17 +292,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
math::FCFunctor<DeviceContext, T> fc;
if (M > D3) {
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data,
bias ? bias->data<T>() : nullptr);
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr);
to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_input->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
batched_input_data,
bias ? bias->data<T>() : nullptr);
fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
bias ? bias->data<T>() : nullptr);
}
auto batched_lod = batched_input->lod();
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
......@@ -281,8 +281,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
int xx_offset = D4;
int gate_offset = D;
......@@ -359,16 +361,15 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCFunctor<DeviceContext, T> fc;
if (M > D4) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
xx_data, bias->data<T>());
fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_input->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, xx_data,
wx_data, batched_input_data,
bias->data<T>());
fc(dev_ctx, x_dims[0], D4, M, xx_data, wx_data, batched_input_data,
bias->data<T>());
}
auto batched_lod = batched_input->lod();
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <algorithm> // for min, max
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
......@@ -209,9 +209,9 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
col_data, w_data, y_data, b_data, true);
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data,
b_data, true);
}
};
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -165,8 +165,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
out_data, b ? b->data<T>() : NULL);
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data,
b ? b->data<T>() : NULL);
w_data = w_data + M0 * D;
// first write on
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
......
......@@ -56,6 +56,7 @@ math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function jit_kernel_helper)
math_library(beam_search DEPS math_function)
math_library(fc DEPS blas)
math_library(matrix_bit_code)
......
......@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -21,34 +20,42 @@ namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL, bool relu = false) {
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
if (relu) {
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
template <typename T>
class FCFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
if (relu) {
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
.At(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
}
}
}
};
template class FCFunctor<platform::CPUDeviceContext, float>;
template class FCFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
......
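The CPU specialization above runs the GEMM through BLAS and then applies the bias row by row with the JIT `VAdd`/`VAddRelu` kernels (the non-ReLU loop is additionally OpenMP-parallelized when MKLML is available). For reference, a plain-C++ sketch of what that epilogue computes; this is for clarity only, not the JIT code path:

```cpp
#include <algorithm>

// Reference epilogue: add B(N) to every row of Y(MxN), optionally apply ReLU.
void AddBiasReluRef(int M, int N, const float* B, float* Y, bool relu) {
  for (int i = 0; i < M; ++i) {
    float* dst = Y + i * N;
    for (int j = 0; j < N; ++j) {
      float v = dst[j] + B[j];
      dst[j] = relu ? std::max(v, 0.0f) : v;
    }
  }
}
```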
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool DoRelu>
__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
for (int i = blockIdx.x; i < M; i += gridDim.x) {
int index = i * N + threadIdx.x;
for (int j = threadIdx.x; j < N; j += blockDim.x) {
T tmp = data[index] + bias[j];
if (DoRelu) {
data[index] = (tmp > 0) ? tmp : 0;
} else {
data[index] = tmp;
}
index += blockDim.x;
}
}
}
template <typename T>
class FCFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false) {
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
static_cast<T>(0.0), Y, N);
if (B == NULL) {
return;
}
const int kThreadsPerBlock = 1024;
int max_threads = context.GetMaxPhysicalThreadCount();
int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
int num_blocks = std::max(max_threads / num_threads, 1);
if (relu) {
InplaceAddReluKernel<
T, true><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
N);
} else {
InplaceAddReluKernel<
T, false><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
N);
}
}
};
template class FCFunctor<platform::CUDADeviceContext, float>;
template class FCFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
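The launch configuration above rounds `N` up to the next multiple of 32 (a full warp), caps the block size at 1024 threads, and sizes the grid to roughly cover the device's physical thread count; `InplaceAddReluKernel` then lets each block grid-stride over the `M` rows while each thread strides across the `N` columns of its row. A small host-side sketch of the arithmetic with hypothetical values (the max-thread count is made up for illustration):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int kThreadsPerBlock = 1024;
  const int N = 100;              // hypothetical output width
  const int max_threads = 40960;  // stand-in for GetMaxPhysicalThreadCount()
  const int num_threads = std::min(kThreadsPerBlock, ((N + 31) >> 5) << 5);
  const int num_blocks = std::max(max_threads / num_threads, 1);
  std::printf("threads per block = %d, blocks = %d\n", num_threads, num_blocks);
  // N = 100 rounds up to 128 threads per block; 40960 / 128 = 320 blocks.
  return 0;
}
```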
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
class FCFunctor {
public:
void operator()(const DeviceContext& context, const int M, const int N,
const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -17,7 +17,7 @@ import numpy as np
from op_test import OpTest
def fc_refer(matrix, with_bias):
def fc_refer(matrix, with_bias, with_relu=False):
in_n, in_c, in_h, in_w = matrix.input.shape
w_i, w_o = matrix.weights.shape
......@@ -31,22 +31,32 @@ def fc_refer(matrix, with_bias):
else:
result = np.dot(x_data, w_data)
return result
if with_relu:
return np.maximum(result, 0)
else:
return result
class MatrixGenerate:
def __init__(self, mb, ic, oc, h, w):
def __init__(self, mb, ic, oc, h, w, bias_dims=2):
self.input = np.random.random((mb, ic, h, w)).astype("float32")
self.weights = np.random.random((ic * h * w, oc)).astype("float32")
self.bias = np.random.random((1, oc)).astype("float32")
if bias_dims == 2:
self.bias = np.random.random((1, oc)).astype("float32")
else:
self.bias = np.random.random((oc)).astype("float32")
class TestFCOp(OpTest):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2)
def setUp(self):
self.op_type = "fc"
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
self.config()
self.with_bias = True
if self.with_bias:
self.inputs = {
'Input': self.matrix.input,
......@@ -56,54 +66,60 @@ class TestFCOp(OpTest):
else:
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {'use_mkldnn': False}
if self.with_relu:
activation_type = "relu"
else:
activation_type = ""
self.attrs = {'use_mkldnn': False, 'activation_type': activation_type}
self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)}
self.outputs = {
'Out': fc_refer(self.matrix, self.with_bias, self.with_relu)
}
def test_check_output(self):
self.check_output()
class TestFCOpNoBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
class TestFCOpNoBias1(TestFCOp):
def config(self):
self.with_bias = False
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOpWithBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
self.with_bias = True
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
self.with_relu = False
self.matrix = MatrixGenerate(2, 8, 10, 1, 1, 2)
class TestFCOp1(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(2, 8, 10, 1, 1)
class TestFCOp2(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpNoBias2(TestFCOp):
def config(self):
self.with_bias = False
self.with_relu = False
self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
class TestFCOp4(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(1, 32, 64, 3, 3)
class TestFCOpNoBias4(TestFCOp):
def config(self):
self.with_bias = False
self.with_relu = False
self.matrix = MatrixGenerate(1, 32, 64, 3, 3, 1)
class TestFCOpWithBias1(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(3, 8, 10, 2, 1)
class TestFCOpWithBias1(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = False
self.matrix = MatrixGenerate(3, 8, 10, 2, 1, 2)
class TestFCOpWithBias2(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpWithBias2(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
class TestFCOpWithBias3(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(1, 64, 32, 3, 3)
class TestFCOpWithBias3(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)
if __name__ == "__main__":
......