提交 a5ddd827 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_py36_py37_ubuntu_dockerfile
...@@ -342,7 +342,7 @@ paddle.fluid.transpiler.RoundRobin.dispatch ArgSpec(args=['self', 'varlist'], va ...@@ -342,7 +342,7 @@ paddle.fluid.transpiler.RoundRobin.dispatch ArgSpec(args=['self', 'varlist'], va
paddle.fluid.transpiler.RoundRobin.reset ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.RoundRobin.reset ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspilerConfig.__init__ paddle.fluid.transpiler.DistributeTranspilerConfig.__init__
paddle.fluid.nets.simple_img_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'pool_size', 'pool_stride', 'pool_padding', 'pool_type', 'global_pooling', 'conv_stride', 'conv_padding', 'conv_dilation', 'conv_groups', 'param_attr', 'bias_attr', 'act', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, 'max', False, 1, 0, 1, 1, None, None, None, True)) paddle.fluid.nets.simple_img_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'pool_size', 'pool_stride', 'pool_padding', 'pool_type', 'global_pooling', 'conv_stride', 'conv_padding', 'conv_dilation', 'conv_groups', 'param_attr', 'bias_attr', 'act', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, 'max', False, 1, 0, 1, 1, None, None, None, True))
paddle.fluid.nets.sequence_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max')) paddle.fluid.nets.sequence_conv_pool ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type', 'bias_attr'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max', None))
paddle.fluid.nets.glu ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,)) paddle.fluid.nets.glu ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,))
paddle.fluid.nets.scaled_dot_product_attention ArgSpec(args=['queries', 'keys', 'values', 'num_heads', 'dropout_rate'], varargs=None, keywords=None, defaults=(1, 0.0)) paddle.fluid.nets.scaled_dot_product_attention ArgSpec(args=['queries', 'keys', 'values', 'num_heads', 'dropout_rate'], varargs=None, keywords=None, defaults=(1, 0.0))
paddle.fluid.nets.img_conv_group ArgSpec(args=['input', 'conv_num_filter', 'pool_size', 'conv_padding', 'conv_filter_size', 'conv_act', 'param_attr', 'conv_with_batchnorm', 'conv_batchnorm_drop_rate', 'pool_stride', 'pool_type', 'use_cudnn'], varargs=None, keywords=None, defaults=(1, 3, None, None, False, 0.0, 1, 'max', True)) paddle.fluid.nets.img_conv_group ArgSpec(args=['input', 'conv_num_filter', 'pool_size', 'conv_padding', 'conv_filter_size', 'conv_act', 'param_attr', 'conv_with_batchnorm', 'conv_batchnorm_drop_rate', 'pool_stride', 'pool_type', 'use_cudnn'], varargs=None, keywords=None, defaults=(1, 3, None, None, False, 0.0, 1, 'max', True))
......
...@@ -46,6 +46,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { ...@@ -46,6 +46,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
prog_file = other.prog_file; prog_file = other.prog_file;
param_file = other.param_file; param_file = other.param_file;
specify_input_name = other.specify_input_name; specify_input_name = other.specify_input_name;
cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
// fields from this. // fields from this.
enable_ir_optim = other.enable_ir_optim; enable_ir_optim = other.enable_ir_optim;
use_feed_fetch_ops = other.use_feed_fetch_ops; use_feed_fetch_ops = other.use_feed_fetch_ops;
...@@ -72,6 +73,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { ...@@ -72,6 +73,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
prog_file = other.prog_file; prog_file = other.prog_file;
param_file = other.param_file; param_file = other.param_file;
specify_input_name = other.specify_input_name; specify_input_name = other.specify_input_name;
cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
// fields from this. // fields from this.
enable_ir_optim = other.enable_ir_optim; enable_ir_optim = other.enable_ir_optim;
use_feed_fetch_ops = other.use_feed_fetch_ops; use_feed_fetch_ops = other.use_feed_fetch_ops;
......
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
...@@ -67,7 +66,7 @@ bool AnalysisPredictor::Init( ...@@ -67,7 +66,7 @@ bool AnalysisPredictor::Init(
#endif #endif
// no matter with or without MKLDNN // no matter with or without MKLDNN
paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
if (!PrepareScope(parent_scope)) { if (!PrepareScope(parent_scope)) {
return false; return false;
...@@ -160,6 +159,14 @@ bool AnalysisPredictor::PrepareExecutor() { ...@@ -160,6 +159,14 @@ bool AnalysisPredictor::PrepareExecutor() {
return true; return true;
} }
void AnalysisPredictor::SetMkldnnThreadID(int tid) {
#ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(tid);
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
#endif
}
bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, std::vector<PaddleTensor> *output_data,
int batch_size) { int batch_size) {
......
...@@ -69,6 +69,8 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -69,6 +69,8 @@ class AnalysisPredictor : public PaddlePredictor {
framework::Scope *scope() { return scope_.get(); } framework::Scope *scope() { return scope_.get(); }
framework::ProgramDesc &program() { return *inference_program_; } framework::ProgramDesc &program() { return *inference_program_; }
void SetMkldnnThreadID(int tid);
protected: protected:
bool PrepareProgram(const std::shared_ptr<framework::ProgramDesc> &program); bool PrepareProgram(const std::shared_ptr<framework::ProgramDesc> &program);
bool PrepareScope(const std::shared_ptr<framework::Scope> &parent_scope); bool PrepareScope(const std::shared_ptr<framework::Scope> &parent_scope);
......
...@@ -28,7 +28,6 @@ limitations under the License. */ ...@@ -28,7 +28,6 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_bool(profile, false, "Turn on profiler for fluid"); DEFINE_bool(profile, false, "Turn on profiler for fluid");
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace { namespace {
...@@ -76,7 +75,7 @@ bool NativePaddlePredictor::Init( ...@@ -76,7 +75,7 @@ bool NativePaddlePredictor::Init(
#endif #endif
// no matter with or without MKLDNN // no matter with or without MKLDNN
paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
......
...@@ -51,9 +51,9 @@ struct AnalysisConfig : public NativeConfig { ...@@ -51,9 +51,9 @@ struct AnalysisConfig : public NativeConfig {
int max_batch_size = 1); int max_batch_size = 1);
bool use_tensorrt() const { return use_tensorrt_; } bool use_tensorrt() const { return use_tensorrt_; }
void EnableMKLDNN();
// NOTE this is just for internal development, please not use it. // NOTE this is just for internal development, please not use it.
// NOT stable yet. // NOT stable yet.
void EnableMKLDNN();
bool use_mkldnn() const { return use_mkldnn_; } bool use_mkldnn() const { return use_mkldnn_; }
friend class ::paddle::AnalysisPredictor; friend class ::paddle::AnalysisPredictor;
......
...@@ -186,6 +186,19 @@ struct NativeConfig : public PaddlePredictor::Config { ...@@ -186,6 +186,19 @@ struct NativeConfig : public PaddlePredictor::Config {
// Specify the variable's name of each input if input tensors don't follow the // Specify the variable's name of each input if input tensors don't follow the
// `feeds` and `fetches` of the phase `save_inference_model`. // `feeds` and `fetches` of the phase `save_inference_model`.
bool specify_input_name{false}; bool specify_input_name{false};
// Set and get the number of cpu math library threads.
void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads) {
cpu_math_library_num_threads_ = cpu_math_library_num_threads;
}
int cpu_math_library_num_threads() const {
return cpu_math_library_num_threads_;
}
protected:
// number of cpu math library (such as MKL, OpenBlas) threads for each
// instance.
int cpu_math_library_num_threads_{1};
}; };
// A factory to help create different predictors. // A factory to help create different predictors.
......
...@@ -27,6 +27,7 @@ void SetConfig(AnalysisConfig *cfg) { ...@@ -27,6 +27,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->device = 0; cfg->device = 0;
cfg->enable_ir_optim = true; cfg->enable_ir_optim = true;
cfg->specify_input_name = true; cfg->specify_input_name = true;
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
} }
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
...@@ -53,6 +53,8 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) { ...@@ -53,6 +53,8 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n"; os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
os << GenSpaces(num_spaces) os << GenSpaces(num_spaces)
<< "specify_input_name: " << config.specify_input_name << "\n"; << "specify_input_name: " << config.specify_input_name << "\n";
os << GenSpaces(num_spaces)
<< "cpu_num_threads: " << config.cpu_math_library_num_threads() << "\n";
num_spaces--; num_spaces--;
os << GenSpaces(num_spaces) << "}\n"; os << GenSpaces(num_spaces) << "}\n";
return os; return os;
......
...@@ -42,6 +42,7 @@ DEFINE_bool(use_analysis, true, ...@@ -42,6 +42,7 @@ DEFINE_bool(use_analysis, true,
"Running the inference program in analysis mode."); "Running the inference program in analysis mode.");
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -206,22 +207,23 @@ void TestMultiThreadPrediction( ...@@ -206,22 +207,23 @@ void TestMultiThreadPrediction(
int batch_size = FLAGS_batch_size; int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat; int num_times = FLAGS_repeat;
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::vector<std::unique_ptr<PaddlePredictor>> predictors; auto main_predictor = CreateTestPredictor(config, use_analysis);
predictors.emplace_back(CreateTestPredictor(config, use_analysis));
for (int tid = 1; tid < num_threads; ++tid) {
predictors.emplace_back(predictors.front()->Clone());
}
size_t total_time{0}; size_t total_time{0};
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
#ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(static_cast<int>(tid) + 1);
#endif
// Each thread should have local inputs and outputs. // Each thread should have local inputs and outputs.
// The inputs of each thread are all the same. // The inputs of each thread are all the same.
std::vector<PaddleTensor> outputs_tid; std::vector<PaddleTensor> outputs_tid;
auto &predictor = predictors[tid]; // To ensure the thread binding correctly,
// please clone inside the threadpool.
auto predictor = main_predictor->Clone();
#ifdef PADDLE_WITH_MKLDNN
if (use_analysis) {
static_cast<AnalysisPredictor *>(predictor.get())
->SetMkldnnThreadID(static_cast<int>(tid) + 1);
}
#endif
// warmup run // warmup run
LOG(INFO) << "Running thread " << tid << ", warm up run..."; LOG(INFO) << "Running thread " << tid << ", warm up run...";
......
...@@ -17,8 +17,6 @@ limitations under the License. */ ...@@ -17,8 +17,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -43,7 +41,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -43,7 +41,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
.template Get<jitkernel::VAddKernel<T>>(N); .template Get<jitkernel::VAddKernel<T>>(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1) #pragma omp parallel for
#endif #endif
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
......
...@@ -41,7 +41,7 @@ void SetNumThreads(int num_threads) { ...@@ -41,7 +41,7 @@ void SetNumThreads(int num_threads) {
#elif defined(PADDLE_WITH_MKLML) #elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1; int real_num_threads = num_threads > 1 ? num_threads : 1;
platform::dynload::MKL_Set_Num_Threads(real_num_threads); platform::dynload::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(num_threads); omp_set_num_threads(real_num_threads);
#else #else
PADDLE_ENFORCE(false, "To be implemented."); PADDLE_ENFORCE(false, "To be implemented.");
#endif #endif
......
...@@ -250,7 +250,8 @@ def sequence_conv_pool(input, ...@@ -250,7 +250,8 @@ def sequence_conv_pool(input,
filter_size, filter_size,
param_attr=None, param_attr=None,
act="sigmoid", act="sigmoid",
pool_type="max"): pool_type="max",
bias_attr=None):
""" """
The sequence_conv_pool is composed with Sequence Convolution and Pooling. The sequence_conv_pool is composed with Sequence Convolution and Pooling.
...@@ -266,6 +267,11 @@ def sequence_conv_pool(input, ...@@ -266,6 +267,11 @@ def sequence_conv_pool(input,
pool_type (str): Pooling type can be :math:`max` for max-pooling, :math:`average` for pool_type (str): Pooling type can be :math:`max` for max-pooling, :math:`average` for
average-pooling, :math:`sum` for sum-pooling, :math:`sqrt` for sqrt-pooling. average-pooling, :math:`sum` for sum-pooling, :math:`sqrt` for sqrt-pooling.
Default :math:`max`. Default :math:`max`.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of sequence_conv.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, sequence_conv
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
Return: Return:
Variable: The final result after Sequence Convolution and Pooling. Variable: The final result after Sequence Convolution and Pooling.
...@@ -289,6 +295,7 @@ def sequence_conv_pool(input, ...@@ -289,6 +295,7 @@ def sequence_conv_pool(input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr,
act=act) act=act)
pool_out = layers.sequence_pool(input=conv_out, pool_type=pool_type) pool_out = layers.sequence_pool(input=conv_out, pool_type=pool_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册