提交 929d22c6 编写于 作者: L Luo Tao

auto set openblas env

上级 cedd9805
......@@ -92,6 +92,9 @@ function threads_config() {
if [ -z "$OPENBLAS_NUM_THREADS" ]; then
export OPENBLAS_NUM_THREADS=$threads
fi
if [ $threads -gt 1 ] && [ -z "$OPENBLAS_MAIN_FREE" ]; then
export OPENBLAS_MAIN_FREE=1
fi
fi
}
......
......@@ -62,12 +62,15 @@ __all__ = [
cp.begin_parse()
def set_omp_mkl_env_vars(trainer_count):
def set_env_vars(trainer_count):
'''Auto set CPU environment if have not set before.
For MKL:
export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status.
export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count.
For OpenBLAS:
export OPENBLAS_NUM_THREADS, OPENBLAS_MAIN_FREE according to trainer_count.
'''
import platform
import platform, paddle
if not platform.system() in ['Linux', 'Darwin']:
return
......@@ -103,6 +106,7 @@ def set_omp_mkl_env_vars(trainer_count):
num_cores = num_physical_cores()
num_processors = num_logical_processors()
if paddle.version.mkl() == 'ON':
if num_processors > num_cores: # Hyper Threading is enabled
set_env("OMP_DYNAMIC", "true")
set_env("KMP_AFFINITY", "granularity=fine,compact,1,0")
......@@ -111,8 +115,13 @@ def set_omp_mkl_env_vars(trainer_count):
set_env("KMP_AFFINITY", "granularity=fine,compact,0,0")
threads = num_processors / trainer_count
threads = '1' if threads < 1 else str(threads)
if paddle.version.mkl() == 'ON':
set_env("OMP_NUM_THREADS", threads)
set_env("MKL_NUM_THREADS", threads)
else:
set_env("OPENBLAS_NUM_THREADS", threads)
if threads > 1:
set_env("OPENBLAS_MAIN_FREE", '1')
def init(**kwargs):
......@@ -129,7 +138,7 @@ def init(**kwargs):
for key in args_dict.keys():
args.append('--%s=%s' % (key, str(args_dict[key])))
set_omp_mkl_env_vars(kwargs.get('trainer_count', 1))
set_env_vars(kwargs.get('trainer_count', 1))
if 'use_gpu' in kwargs:
cp.g_command_config_args['use_gpu'] = kwargs['use_gpu']
......
......@@ -31,6 +31,7 @@ patch = '%(patch)d'
rc = '%(rc)d'
istaged = %(istaged)s
commit = '%(commit)s'
with_mkl = '%(with_mkl)s'
def show():
if istaged:
......@@ -41,6 +42,9 @@ def show():
print 'rc:', rc
else:
print 'commit:', commit
def mkl():
return with_mkl
'''
commit = git_commit()
with open(filename, 'w') as f:
......@@ -51,7 +55,8 @@ def show():
'rc': RC,
'version': '${PADDLE_VERSION}',
'commit': commit,
'istaged': ISTAGED})
'istaged': ISTAGED,
'with_mkl': '@WITH_MKL@'})
write_version_py(filename='@PADDLE_SOURCE_DIR@/python/paddle/version.py')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册