提交 3df38f5c 编写于 作者: G GaoWei8 提交者: Yiqun Liu

[cherry-pick] Add FC padding, ernie test unit and layernorm parallel (#22198)

* Optimize the kernel implementation of layernorm with openmp (#20895)

* Add ernie c++ inference test (#21015)

* Add ernie unit test
test=develop

* Add ernie unit test
test=develop

* Add ernie unit test
test=develop

* remove ngraph

* optimize gpu test
test=develop

* optimize codes
test=develop

* fix cmake fails on inference_download_and_uncompress (#21185)

* solve cmake fails on inference_download_and_uncompress
test=develop

* solve cmake fails on inference_download_and_uncompress
test=develop

* Add fc padding to improve mkl GEMM's performance when N and K are multiple of 128. (#20972)

* Add fc padding to solve mkl performance
test=develop

* fix gpu pass and error information
test=develop

* fix fc_fuse_pass_test
test=develop

* fix error information
test=develop

* fix error information
test=develop

* fix name and add fc op padding test
test=develop

* fix attributes
test=develop

* optimize fc padding
test=develop

* fix test
test=develop

* Polish the codes of fc when needs padding (#21378)

test=develop

* Add ernie large c++ inference test (#21365)

* add ernie-large test
test=develop

* add ernie large c++ inference test
test=develop

* Modify padding strategy: remove weight copy in fc padding (#21650)

test=develop

* optimize fc jit (#21878)

test=develop
Co-authored-by: NYihua Xu <yihuaxu@hotmail.com>
上级 e8e12499
......@@ -89,6 +89,35 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
std::string activation_type = with_relu ? "relu" : "";
desc.SetAttr("activation_type", activation_type);
// This is to add padding for dimension 128 on concern of MKL performance
auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
auto place = weight->place();
bool use_gpu = Get<bool>("use_gpu");
auto* weight_data = weight->data<float>();
auto weight_dims = weight->dims();
int weight_num = product(weight_dims);
int w_h = weight_dims[0];
int w_w = weight_dims[1];
if (!use_gpu) {
if (w_h % 128 == 0 && w_w % 128 == 0) {
auto* weight_data_tmp = new float[weight_num];
for (int i = 0; i < w_h; i++) {
memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
w_w * sizeof(float));
}
weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
auto* weight_data_new =
weight->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < w_h; i++) {
memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
w_w * sizeof(float));
}
delete[] weight_data_tmp;
desc.SetAttr("padding_weights", true);
}
}
// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
......
......@@ -21,6 +21,24 @@ namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "conv2d_filters_0", {});
AddVarToScope(param_scope, "conv2d_bias_0", {});
AddVarToScope(param_scope, "weights_0", {});
AddVarToScope(param_scope, "weights_1", {});
AddVarToScope(param_scope, "bias_1", {});
AddVarToScope(param_scope, "bias_2", {});
return param_scope;
}
TEST(FCFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
......@@ -50,6 +68,8 @@ TEST(FCFusePass, basic) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
pass->Set("use_gpu", new bool(true));
graph->Set("__param_scope__", CreateParamScope());
int num_nodes_before = graph->Nodes().size();
int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
VLOG(3) << DebugString(graph);
......
......@@ -27,10 +27,14 @@ function(download_model_and_data install_dir model_name data_name)
download_data(${install_dir} ${data_name})
endfunction()
function(download_result install_dir result_name)
download_data(${install_dir} ${result_name})
endfunction()
function(inference_analysis_api_test target install_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt)
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt --refer_result=${install_dir}/result.txt)
endfunction()
function(inference_analysis_api_test_build TARGET_NAME filename)
......@@ -72,13 +76,6 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
--disable_mkldnn_fc=${disable_fc})
endfunction()
function(inference_analysis_api_test_with_refer_result target install_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt
--refer_result=${install_dir}/result.txt)
endfunction()
function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
inference_analysis_test_run(${TARGET_NAME}
COMMAND ${test_binary}
......@@ -147,6 +144,20 @@ set(PYRAMID_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/pyramid_dnn")
download_model_and_data(${PYRAMID_DNN_INSTALL_DIR} "PyramidDNN_model.tar.gz" "PyramidDNN_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR} analyzer_pyramid_dnn_tester.cc)
#Ernie
set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie")
download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" "Ernie_data.txt.tar.gz" "Ernie_result.txt.tar.gz")
download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz")
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc)
#Ernie large
set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_Large")
download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_large_model.tar.gz" "Ernie_large_data.txt.tar.gz" "Ernie_large_result.txt.tar.gz")
download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
......@@ -170,14 +181,14 @@ set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
### Image classification tests with fake data
set(IMG_CLASS_TEST_APP "test_analyzer_image_classification")
......@@ -334,13 +345,9 @@ inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${RESNET50_MODEL_DIR}/model)
set(CAPI_MODEL_INSTALL_PD_DIR "${INFERENCE_DEMO_INSTALL_DIR}/capi_mobilenet")
if (NOT EXISTS ${CAPI_MODEL_INSTALL_PD_DIR})
inference_download_and_uncompress(${CAPI_MODEL_INSTALL_PD_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_test(test_analyzer_capi_pd_tensor SRCS analyzer_capi_pd_tensor_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${CAPI_MODEL_INSTALL_PD_DIR}/model)
ARGS --infer_model=${MOBILENET_INSTALL_DIR}/model)
if(WITH_MKLDNN)
inference_analysis_test(test_analyzer_capi_int SRCS analyzer_capi_int_tester.cc
......
......@@ -153,7 +153,6 @@ void profile(bool use_mkldnn = false, bool use_ngraph = false) {
if (use_mkldnn) {
config.EnableMKLDNN();
config.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
......@@ -193,7 +192,6 @@ void compare(bool use_mkldnn = false, bool use_ngraph = false) {
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
}
if (use_ngraph) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
// The unit-test dataset only have 10 samples, each sample have 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false,
bool use_gpu = false) {
cfg->SetModel(FLAGS_infer_model);
if (use_mkldnn) {
cfg->EnableMKLDNN();
}
if (use_gpu) {
cfg->EnableUseGpu(100, 0);
} else {
cfg->DisableGpu();
}
cfg->SwitchSpecifyInputNames();
cfg->SwitchIrOptim();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
}
void profile(bool use_mkldnn = false, bool use_gpu = false) {
AnalysisConfig config;
SetConfig(&config, use_mkldnn, use_gpu);
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&config),
inputs, &outputs, FLAGS_num_threads);
}
TEST(Analyzer_ernie, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_ernie, profile_mkldnn) { profile(true, false); }
#endif
// Check the model by gpu
#ifdef PADDLE_WITH_CUDA
TEST(Analyzer_ernie, profile_gpu) { profile(false, true); }
#endif
// Check the fuse status
TEST(Analyzer_Ernie, fuse_statis) {
AnalysisConfig cfg;
SetConfig(&cfg);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
LOG(INFO) << "num_ops: " << num_ops;
if (FLAGS_ernie_large) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 146);
EXPECT_EQ(num_ops, 859);
} else {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 74);
EXPECT_EQ(num_ops, 295);
}
}
// Compare result of NativeConfig and AnalysisConfig
void compare(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg, use_mkldnn, false);
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), inputs);
}
TEST(Analyzer_ernie, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_ernie, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
// Compare Deterministic result
TEST(Analyzer_Ernie, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
// Compare results
TEST(Analyzer_Ernie, compare_results) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
auto predictor = CreateTestPredictor(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
FLAGS_use_analysis);
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto outputs_size = outputs.front().data.length() / (sizeof(float));
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j],
static_cast<float *>(outputs[0].data.data())[j],
FLAGS_accuracy);
}
}
}
} // namespace inference
} // namespace paddle
......@@ -44,6 +44,7 @@ DEFINE_string(int8_model, "", "INT8 model path");
DEFINE_string(infer_data, "", "data file");
DEFINE_string(refer_result, "", "reference result for comparison");
DEFINE_int32(batch_size, 1, "batch size");
DEFINE_bool(ernie_large, false, "Test ernie large");
DEFINE_bool(with_accuracy_layer, true,
"Calculate the accuracy while label is in the input");
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
......
......@@ -32,17 +32,33 @@ class FCOp : public framework::OperatorWithKernel {
auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
bool padding_weights = ctx->Attrs().Get<bool>("padding_weights");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1,
"The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
platform::errors::InvalidArgument(
"The shape of Bias is invalid."
"The height of Bias should be 1."
"But received height of Bias is %d.",
bias_dims[0]));
PADDLE_ENFORCE_EQ(
bias_dims[1], w_dims1,
platform::errors::InvalidArgument(
"The shape of Bias is invalid."
"The width of Bias should be equal to width of Weight."
"But received width of Bias is %d and width of Weight is %d.",
bias_dims[1], w_dims1));
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
"The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(
bias_dims[0], w_dims1,
platform::errors::InvalidArgument(
"The shape of Bias is invalid."
"The height of Bias should be equal to the width of weight."
"But received height of Bias is %d and width of Weight is %d.",
bias_dims[0], w_dims1));
}
}
......@@ -65,7 +81,8 @@ class FCOp : public framework::OperatorWithKernel {
"in_num_col_dims.");
std::vector<int64_t> output_dims;
FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims,
padding_weights);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
......@@ -138,6 +155,11 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>(
"padding_weights",
"(bool, default false) When padding weights in the fc fuse pass, "
"the 'padding_weights' attribute is set as true.")
.SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
......
......@@ -38,17 +38,21 @@ class FCOpGrad : public framework::OperatorWithKernel {
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
int in_num_col_dims) {
int in_num_col_dims, bool padding_weights) {
auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
PADDLE_ENFORCE_EQ(
in_mat_dims[1], w_dims[0],
"Fully Connected input and weigth size do not match. %s, %s");
auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
PADDLE_ENFORCE_EQ(in_mat_dims[1], w_dims0,
platform::errors::InvalidArgument(
"Fully Connected input and weigth size do not match. "
"input width: %d,weight height: %d",
in_mat_dims[1], w_dims0));
out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
for (int i = 0; i < in_num_col_dims; ++i) {
out_dims.push_back(in_dims[i]);
}
out_dims.push_back(w_dims[1]);
out_dims.push_back(w_dims1);
}
template <typename DeviceContext, typename T>
......@@ -64,14 +68,18 @@ class FCOpKernel : public framework::OpKernel<T> {
(ctx.Attr<std::string>("activation_type") == "relu") ? true : false;
auto w_dims = w->dims();
bool padding_weights = ctx.Attr<bool>("padding_weights");
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims,
padding_weights);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims();
int M = framework::product(out_dims) / w_dims[1];
auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
int M = framework::product(out_dims) / w_dims1;
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
......@@ -79,8 +87,8 @@ class FCOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu);
fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu, padding_weights);
}
};
......
......@@ -26,131 +26,141 @@ namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right) {
__m256 sum;
__m256 mean_vec, var_vec;
__m128 hi, lo;
__m256 tmp;
size_t offset;
size_t j;
int block = YMM_FLOAT_BLOCK;
const int rest = right % block;
const int end = right - rest;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel
{
#endif
__m256 sum;
__m256 mean_vec, var_vec;
__m128 hi, lo;
__m256 tmp;
size_t offset;
size_t j;
__m256 reverse_num_vec =
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
__m256 epsilon_vec = _mm256_set1_ps(epsilon);
int rest_mask =
((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
__m256i mask_vec = _mm256_set_epi32(
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
__m256 reverse_num_vec =
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
__m256 epsilon_vec = _mm256_set1_ps(epsilon);
int rest_mask =
((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
__m256i mask_vec = _mm256_set_epi32(
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
#ifdef PADDLE_WITH_MKLML
#pragma omp for
#endif
for (int i = 0; i < height; ++i) {
offset = i * right;
for (int i = 0; i < height; ++i) {
offset = i * right;
/* get mean */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)x + j);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
mean[i] = *reinterpret_cast<float*>(&mean_vec);
/* get variance */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
sum = _mm256_add_ps(sum, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
var_vec = _mm256_mul_ps(sum, reverse_num_vec);
var[i] = *reinterpret_cast<float*>(&var_vec);
/* get x_norm and calculate output*/
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (scale) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
/* get mean */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)scale + j - offset)));
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(tmp,
_mm256_loadu_ps((const float*)scale + j - offset)));
tmp = _mm256_loadu_ps((const float*)x + j);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
mean[i] = *reinterpret_cast<float*>(&mean_vec);
if (bias) {
/* get variance */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
sum = _mm256_add_ps(sum, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
var_vec = _mm256_mul_ps(sum, reverse_num_vec);
var[i] = *reinterpret_cast<float*>(&var_vec);
/* get x_norm and calculate output*/
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)bias + j - offset)));
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
j - offset)));
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (scale) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)scale + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(tmp,
_mm256_loadu_ps((const float*)scale + j - offset)));
}
}
if (bias) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)bias + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp,
_mm256_loadu_ps((const float*)bias + j - offset)));
}
}
}
#ifdef PADDLE_WITH_MKLML
}
#endif
}
bool LayerNormKernel::CanBeUsed(const int& d) const {
......
......@@ -25,31 +25,56 @@ class FCFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false) {
const T* B = nullptr, bool relu = false,
bool padding_weights = false) {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
if (relu) {
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
.At(N);
framework::Tensor Y1;
T* Y1_data = nullptr;
if (padding_weights) {
const int NN = N + 4;
const int KK = K + 4;
framework::Tensor X1;
T* X1_data = X1.mutable_data<T>({M * KK}, platform::CPUPlace());
Y1_data = Y1.mutable_data<T>({M * (N + 4)}, platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
}
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK, W, NN,
static_cast<T>(0.0), Y1_data, NN);
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
N);
blas.MatMul(M, N, K, X, W, Y);
}
if (B == NULL) {
if (padding_weights) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
for (int i = 0; i < M; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
}
}
PADDLE_ENFORCE_EQ(relu, false,
platform::errors::PermissionDenied(
"When bias is NULL, relu can not be true."));
return;
}
auto compute =
relu
? jit::KernelFuncs<jit::VAddReluTuple<T>,
platform::CPUPlace>::Cache()
.At(N)
: jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache()
.At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
}
};
......
......@@ -41,7 +41,12 @@ class FCFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false) {
const T* B = nullptr, bool relu = false,
bool padding_weights = false) {
PADDLE_ENFORCE_EQ(
padding_weights, false,
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in GPU scope."));
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
static_cast<T>(0.0), Y, N);
......
......@@ -26,7 +26,8 @@ class FCFunctor {
public:
void operator()(const DeviceContext& context, const int M, const int N,
const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false);
const T* B = nullptr, bool relu = false,
bool weight_pass = false);
};
} // namespace math
......
......@@ -207,8 +207,13 @@ class FCPrimitiveFactory {
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, LoDTensor* output) {
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
bool padding_weights = ctx.Attr<bool>("padding_weights");
PADDLE_ENFORCE_EQ(padding_weights, false,
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in MKLDNN."));
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims,
padding_weights);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
}
......
......@@ -124,6 +124,13 @@ class TestFCOpWithBias3(TestFCOp):
self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)
class TestFCOpWithPadding(TestFCOp):
def config(self):
self.with_bias = True
self.with_relu = True
self.matrix = MatrixGenerate(1, 4, 3, 128, 128, 2)
class TestFCOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册