未验证 提交 a65c728e 编写于 作者: Y Yiqun Liu 提交者: GitHub

Implement the GPU kernel of fc operator (#19687)

* Refine the codes related to fc op.

* Add GPU implementation for fc functor.

* Apply fc_fuse_pass in GPU inference.
test=develop

* Change the cmake for fc op.

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.

* Add an attribute to set the activation type in fc_op.

* Enhance the unittest of fc_op.
test=develop

* Remove the declaration of FCOpGrad back to the header file.
test=develop

* Set default value for newly added arguments in test_fc_op.
test=develop
上级 22301115
develop 1.8.5 2.0.1-rocm-post 2.4.1 Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_kylinv10 add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease bugfix-eval-frame-leakgae cherry-pick-fix-customOP-random-fail cherry_undefined_var compile_windows cp_2.4_fix_numpy delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_paddle_tiny_install delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix-run-program-grad-node-mem fix_check fix_concat_slice fix_custom_device_copy_sync fix_dataloader_memory_leak fix_dlpack_for fix_imperative_dygraph_error fix_newexe_gc fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fix_var_stop_gradient_error fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 hack_event improve_sccache incuabte/new_frl incubate/frl_train_eval incubate/infrt incubate/new_frl incubate/new_frl_rc incubate/stride inplace_addto layer_norm make_flag_adding_easier matmul_double_grad move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel operator_opt paddle_tiny_install paralleltest pass-compile-eval-frame preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 prv-reshape-mkldnn-ut2 pten_tensor_refactor release-deleted/2.5 release-rc/2.5 release/1.6 release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 release/2.5 release/llm_2.5 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment revert-47325-remove_cudnn_hardcode revert-47645-add_npu_storage_dims revert-48815-set_free_when_no_cache_hit_default_value_true revert-49499-test_ninja_on_ci revert-49654-prim_api_gen revert-49673-modify_get_single_cov revert-49763-fix_static_composite_gen revert-50158-fix_found_inf_bug_for_custom_optimizer revert-50188-refine_optimizer_create_accumulators revert-50335-fix_optminizer_set_auxiliary_var_bug revert-51676-flag_delete revert-51850-fix_softmaxce_dev revert-52175-dev_peak_memory revert-52186-deve revert-52523-test_py38 revert-52912-develop revert-53248-set_cmake_policy revert-54029-fix_windows_compile_bug revert-54068-support_translating_op_attribute revert-54214-modify_cmake_dependencies revert-54370-offline_pslib revert-54391-fix_cmake_md5error revert-54411-fix_cpp17_compile revert-54466-offline_pslib revert-54480-cmake-rocksdb revert-55568-fix_BF16_bug1 revert-56328-new_ir_support_vector_type_place_transfer revert-56366-fix_openssl_bug revert-56545-revert-56366-fix_openssl_bug revert-56620-fix_new_ir_ocr_bug revert-56925-check_inputs_grad_semantic revert-57005-refine_stride_flag rocm_dev_0217 sd_conv_linear_autocast semi-auto/rule-base support-0D-sort support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_for_Filtetfiles test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.5.1 v2.5.0 v2.5.0-rc1 v2.5.0-rc0 v2.4.2 v2.4.1 v2.4.0 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0 v1.6.3 v1.6.2 v1.6.1 v1.6.0 v1.6.0-rc0
无相关合并请求
......@@ -191,9 +191,6 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......
......@@ -21,7 +21,6 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......
......@@ -106,6 +106,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"fc_fuse_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", //
......
......@@ -90,7 +90,7 @@ endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -339,10 +339,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
atted_x_data, atten_b_data);
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data,
atten_b_data);
const T* cur_atten_x_data = atted_x_data;
const T* cur_x_data = x_data;
......@@ -369,8 +372,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
// 1d. softmax
vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
// mul x(seq_len*M) and sum pool
math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
cur_x_data, lstm_x_data);
fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data);
/// 2. compute LSTM step
// lstm weight : concat[forget , input , output , tilde]
......
......@@ -14,18 +14,20 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
void FCOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
class FCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"X(Input) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
"W(Input) of Fully Connected should not be null.");
auto in_dims = ctx->GetInputDim("Input");
......@@ -34,7 +36,8 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[0], 1,
"The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
} else if (bias_dims.size() == 1) {
......@@ -43,8 +46,14 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
}
}
auto& activation_type = ctx->Attrs().Get<std::string>("activation_type");
if (!activation_type.empty()) {
PADDLE_ENFORCE_EQ(activation_type, "relu",
"Activation %s is not supportetd in fc now.",
activation_type.c_str());
}
if (ctx->Attrs().Get<bool>("use_mkldnn")) {
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
PADDLE_ENFORCE_EQ(in_dims.size() == 2 || in_dims.size() == 4, true,
"Fully Connected input should be 2-D or 4-D tensor.");
}
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
......@@ -60,10 +69,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
}
}
framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
......@@ -72,7 +82,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
}
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
}
};
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
......@@ -86,7 +97,7 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
"Should have bias grad");
auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
......@@ -105,17 +116,24 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
ctx.GetPlace(), layout, library);
}
void FCOpMaker::Make() {
AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor), The input tensor of fully connected operator.");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddOutput("Out",
"(Tensor) The output tensor of fully connected operator. ");
AddAttr<int>("in_num_col_dims",
"(int, default 1), The fc op can take tensors with more than "
"two dimensions as its inputs.")
.SetDefault(1)
.EqualGreaterThan(1);
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<std::string>("activation_type",
"Avctivation type used in fully connected operator.")
.SetDefault("");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
......@@ -123,43 +141,11 @@ void FCOpMaker::Make() {
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
AddComment(R"DOC(
Fully Connected Operator.
Fully Connected Operator.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters checked in the infer-shape.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters checked in the infer-shape.
)DOC");
}
template <typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto input = ctx.Input<framework::LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
auto w_dims = w->dims();
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims();
int M = framework::product(out_dims) / w_dims[1];
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
math::FCCompute<platform::CPUDeviceContext, T>(
blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL);
// TODO(TJ): fuse act
}
};
......@@ -170,4 +156,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
REGISTER_OP_CPU_KERNEL(
fc, ops::FCOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FCOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fc, ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -14,24 +14,16 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FCOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -43,11 +35,6 @@ class FCOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override;
};
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
......@@ -64,5 +51,38 @@ inline void FCOutputSize(const framework::DDim& in_dims,
out_dims.push_back(w_dims[1]);
}
template <typename DeviceContext, typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* w = ctx.Input<Tensor>("W");
auto* bias = ctx.Input<Tensor>("Bias");
auto* output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
bool with_relu =
(ctx.Attr<std::string>("activation_type") == "relu") ? true : false;
auto w_dims = w->dims();
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims();
int M = framework::product(out_dims) / w_dims[1];
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu);
}
};
} // namespace operators
} // namespace paddle
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
......@@ -219,8 +219,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const T* wh_state_data = wh_data + D * D2;
T* hidden_out_data = hidden_out->mutable_data<T>(place);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data,
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr);
int xx_offset = D3;
......@@ -290,16 +292,16 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
math::FCFunctor<DeviceContext, T> fc;
if (M > D3) {
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data,
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr);
to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_input->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
batched_input_data,
fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
bias ? bias->data<T>() : nullptr);
}
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
......@@ -281,8 +281,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T* h_out_data = hidden_out->mutable_data<T>(place);
T* c_out_data = cell_out->mutable_data<T>(place);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
int xx_offset = D4;
int gate_offset = D;
......@@ -359,15 +361,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCFunctor<DeviceContext, T> fc;
if (M > D4) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
xx_data, bias->data<T>());
fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_input->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, xx_data,
wx_data, batched_input_data,
fc(dev_ctx, x_dims[0], D4, M, xx_data, wx_data, batched_input_data,
bias->data<T>());
}
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <algorithm> // for min, max
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
......@@ -209,9 +209,9 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
col_data, w_data, y_data, b_data, true);
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data,
b_data, true);
}
};
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -165,8 +165,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
out_data, b ? b->data<T>() : NULL);
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data,
b ? b->data<T>() : NULL);
w_data = w_data + M0 * D;
// first write on
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
......
......@@ -56,6 +56,7 @@ math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function jit_kernel_helper)
math_library(beam_search DEPS math_function)
math_library(fc DEPS blas)
math_library(matrix_bit_code)
......
......@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -21,25 +20,29 @@ namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
template <typename T>
class FCFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL, bool relu = false) {
const T* B = nullptr, bool relu = false) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
if (relu) {
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
N);
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
.At(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
......@@ -48,7 +51,11 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
compute(B, dst, dst, N);
}
}
}
}
};
template class FCFunctor<platform::CPUDeviceContext, float>;
template class FCFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool DoRelu>
__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
for (int i = blockIdx.x; i < M; i += gridDim.x) {
int index = i * N + threadIdx.x;
for (int j = threadIdx.x; j < N; j += blockDim.x) {
T tmp = data[index] + bias[j];
if (DoRelu) {
data[index] = (tmp > 0) ? tmp : 0;
} else {
data[index] = tmp;
}
index += blockDim.x;
}
}
}
template <typename T>
class FCFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false) {
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
static_cast<T>(0.0), Y, N);
if (B == NULL) {
return;
}
const int kThreadsPerBlock = 1024;
int max_threads = context.GetMaxPhysicalThreadCount();
int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
int num_blocks = std::max(max_threads / num_threads, 1);
if (relu) {
InplaceAddReluKernel<
T, true><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
N);
} else {
InplaceAddReluKernel<
T, false><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
N);
}
}
};
template class FCFunctor<platform::CUDADeviceContext, float>;
template class FCFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
class FCFunctor {
public:
void operator()(const DeviceContext& context, const int M, const int N,
const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -17,7 +17,7 @@ import numpy as np
from op_test import OpTest
def fc_refer(matrix, with_bias):
def fc_refer(matrix, with_bias, with_relu=False):
in_n, in_c, in_h, in_w = matrix.input.shape
w_i, w_o = matrix.weights.shape
......@@ -31,22 +31,32 @@ def fc_refer(matrix, with_bias):
else:
result = np.dot(x_data, w_data)
if with_relu:
return np.maximum(result, 0)
else:
return result
class MatrixGenerate:
def __init__(self, mb, ic, oc, h, w):
def __init__(self, mb, ic, oc, h, w, bias_dims=2):
self.input = np.random.random((mb, ic, h, w)).astype("float32")
self.weights = np.random.random((ic * h * w, oc)).astype("float32")
if bias_dims == 2:
self.bias = np.random.random((1, oc)).astype("float32")
else:
self.bias = np.random.random((oc)).astype("float32")
class TestFCOp(OpTest):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2)
def setUp(self):
self.op_type = "fc"
self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
self.config()
self.with_bias = True
if self.with_bias:
self.inputs = {
'Input': self.matrix.input,
......@@ -56,54 +66,60 @@ class TestFCOp(OpTest):
else:
self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
self.attrs = {'use_mkldnn': False}
if self.with_relu:
activation_type = "relu"
else:
activation_type = ""
self.attrs = {'use_mkldnn': False, 'activation_type': activation_type}
self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)}
self.outputs = {
'Out': fc_refer(self.matrix, self.with_bias, self.with_relu)
}
def test_check_output(self):
self.check_output()
class TestFCOpNoBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
class TestFCOpNoBias1(TestFCOp):
def config(self):
self.with_bias = False
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
self.with_relu = False
self.matrix = MatrixGenerate(2, 8, 10, 1, 1, 2)
class TestFCOpWithBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
self.with_bias = True
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(2, 8, 10, 1, 1)
class TestFCOp2(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpNoBias2(TestFCOp):
def config(self):
self.with_bias = False
self.with_relu = False
self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
class TestFCOp4(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(1, 32, 64, 3, 3)
class TestFCOpNoBias4(TestFCOp):
def config(self):
self.with_bias = False
self.with_relu = False
self.matrix = MatrixGenerate(1, 32, 64, 3, 3, 1)
class TestFCOpWithBias1(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(3, 8, 10, 2, 1)
class TestFCOpWithBias1(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = False
self.matrix = MatrixGenerate(3, 8, 10, 2, 1, 2)
class TestFCOpWithBias2(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpWithBias2(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
class TestFCOpWithBias3(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(1, 64, 32, 3, 3)
class TestFCOpWithBias3(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部