Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ac78cc04
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ac78cc04
编写于
4月 12, 2018
作者:
_青葱
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into ini
上级
c97003df
8d4d6eae
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
131 addition
and
73 deletion
+131
-73
cmake/cblas.cmake
cmake/cblas.cmake
+19
-15
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+3
-3
cmake/external/snappy.cmake
cmake/external/snappy.cmake
+8
-8
cmake/external/snappystream.cmake
cmake/external/snappystream.cmake
+7
-7
cmake/generic.cmake
cmake/generic.cmake
+2
-13
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+32
-0
paddle/CMakeLists.txt
paddle/CMakeLists.txt
+1
-1
paddle/fluid/CMakeLists.txt
paddle/fluid/CMakeLists.txt
+2
-1
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+2
-2
paddle/fluid/inference/io.cc
paddle/fluid/inference/io.cc
+6
-0
paddle/fluid/inference/io.h
paddle/fluid/inference/io.h
+3
-0
paddle/fluid/inference/tests/book/CMakeLists.txt
paddle/fluid/inference/tests/book/CMakeLists.txt
+1
-1
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+3
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+13
-8
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+23
-10
python/paddle/fluid/param_attr.py
python/paddle/fluid/param_attr.py
+6
-3
未找到文件。
cmake/cblas.cmake
浏览文件 @
ac78cc04
...
...
@@ -62,29 +62,33 @@ endif()
## Then find the reference-cblas. www.netlib.org/blas/
set
(
REFERENCE_CBLAS_ROOT $ENV{REFERENCE_CBLAS_ROOT} CACHE PATH
"Folder contains reference-cblas"
)
set
(
REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS
if
(
NOT CMAKE_CROSSCOMPILING
)
set
(
REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS
${
REFERENCE_CBLAS_ROOT
}
/include
/usr/include
/usr/include/cblas
)
)
set
(
REFERENCE_CBLAS_LIB_SEARCH_PATHS
set
(
REFERENCE_CBLAS_LIB_SEARCH_PATHS
${
REFERENCE_CBLAS_ROOT
}
/lib
/usr/lib
/usr/lib/blas/reference/
/usr/lib/reference/
)
)
else
()
# Diable the finding of reference cblas under host's system path
set
(
REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS
${
REFERENCE_CBLAS_ROOT
}
/include
)
set
(
REFERENCE_CBLAS_LIB_SEARCH_PATHS
${
REFERENCE_CBLAS_ROOT
}
/lib
)
endif
()
find_path
(
REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
${
REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS
}
)
find_library
(
REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
${
REFERENCE_CBLAS_LIB_SEARCH_PATHS
}
)
if
(
REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY
)
if
(
REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY
)
set
(
CBLAS_FOUND ON
)
set
(
CBLAS_PROVIDER REFERENCE
)
set
(
CBLAS_INC_DIR
${
REFERENCE_CBLAS_INCLUDE_DIR
}
)
...
...
cmake/external/grpc.cmake
浏览文件 @
ac78cc04
...
...
@@ -24,16 +24,16 @@ SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc)
SET
(
GRPC_INCLUDE_DIR
"
${
GRPC_INSTALL_DIR
}
/include/"
CACHE PATH
"grpc include directory."
FORCE
)
SET
(
GRPC_CPP_PLUGIN
"
${
GRPC_INSTALL_DIR
}
/bin/grpc_cpp_plugin"
CACHE FILEPATH
"GRPC_CPP_PLUGIN"
FORCE
)
IF
(
APPLE
)
SET
(
BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j
8
static grpc_cpp_plugin | sed
"s/-Werror//g"
| sh
)
SET
(
BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j static grpc_cpp_plugin | sed
"s/-Werror//g"
| sh
)
ELSE
()
SET
(
BUILD_CMD make HAS_SYSTEM_PROTOBUF=false -s -j
8
static grpc_cpp_plugin
)
SET
(
BUILD_CMD make HAS_SYSTEM_PROTOBUF=false -s -j static grpc_cpp_plugin
)
ENDIF
()
ExternalProject_Add
(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY
"https://github.com/grpc/grpc.git"
GIT_TAG
"v1.
8
.x"
GIT_TAG
"v1.
11
.x"
PREFIX
${
GRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
cmake/external/snappy.cmake
浏览文件 @
ac78cc04
...
...
@@ -11,19 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF
(
MOBILE_INFERENCE
)
if
(
MOBILE_INFERENCE OR RPI
)
return
()
ENDIF
()
endif
()
include
(
ExternalProject
)
# NOTE: snappy is needed when linking with recordio
SET
(
SNAPPY_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/snappy
)
SET
(
SNAPPY_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/snappy
)
SET
(
SNAPPY_INCLUDE_DIR
"
${
SNAPPY_INSTALL_DIR
}
/include/"
CACHE PATH
"snappy include directory."
FORCE
)
set
(
SNAPPY_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/snappy
)
set
(
SNAPPY_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/snappy
)
set
(
SNAPPY_INCLUDE_DIR
"
${
SNAPPY_INSTALL_DIR
}
/include"
CACHE PATH
"snappy include directory."
FORCE
)
set
(
SNAPPY_LIBRARIES
"
${
SNAPPY_INSTALL_DIR
}
/lib/libsnappy.a"
)
ExternalProject_Add
(
extern_snappy
...
...
@@ -51,8 +52,7 @@ ExternalProject_Add(
)
add_library
(
snappy STATIC IMPORTED GLOBAL
)
set_property
(
TARGET snappy PROPERTY IMPORTED_LOCATION
"
${
SNAPPY_INSTALL_DIR
}
/lib/libsnappy.a"
)
set_property
(
TARGET snappy PROPERTY IMPORTED_LOCATION
${
SNAPPY_LIBRARIES
}
)
include_directories
(
${
SNAPPY_INCLUDE_DIR
}
)
add_dependencies
(
snappy extern_snappy
)
cmake/external/snappystream.cmake
浏览文件 @
ac78cc04
...
...
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF
(
MOBILE_INFERENCE
)
IF
(
MOBILE_INFERENCE
OR RPI
)
return
()
ENDIF
()
...
...
@@ -21,9 +20,11 @@ include (ExternalProject)
# NOTE: snappy is needed when linking with recordio
SET
(
SNAPPYSTREAM_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/snappy_stream
)
SET
(
SNAPPYSTREAM_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/snappy_stream
)
SET
(
SNAPPYSTREAM_INCLUDE_DIR
"
${
SNAPPYSTREAM_INSTALL_DIR
}
/include/"
CACHE PATH
"snappy stream include directory."
FORCE
)
set
(
SNAPPYSTREAM_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/snappy_stream
)
set
(
SNAPPYSTREAM_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/snappy_stream
)
set
(
SNAPPYSTREAM_INCLUDE_DIR
"
${
SNAPPYSTREAM_INSTALL_DIR
}
/include"
CACHE PATH
"snappy stream include directory."
FORCE
)
set
(
SNAPPYSTREAM_LIBRARIES
"
${
SNAPPYSTREAM_INSTALL_DIR
}
/lib/libsnappystream.a"
)
ExternalProject_Add
(
extern_snappystream
...
...
@@ -51,8 +52,7 @@ ExternalProject_Add(
)
add_library
(
snappystream STATIC IMPORTED GLOBAL
)
set_property
(
TARGET snappystream PROPERTY IMPORTED_LOCATION
"
${
SNAPPYSTREAM_INSTALL_DIR
}
/lib/libsnappystream.a"
)
set_property
(
TARGET snappystream PROPERTY IMPORTED_LOCATION
${
SNAPPYSTREAM_LIBRARIES
}
)
include_directories
(
${
SNAPPYSTREAM_INCLUDE_DIR
}
)
# For snappysteam to include its own headers.
include_directories
(
${
THIRD_PARTY_PATH
}
/install
)
# For Paddle to include snappy stream headers.
...
...
cmake/generic.cmake
浏览文件 @
ac78cc04
...
...
@@ -195,14 +195,7 @@ function(cc_library TARGET_NAME)
list
(
REMOVE_ITEM cc_library_DEPS warpctc
)
add_dependencies
(
${
TARGET_NAME
}
warpctc
)
endif
()
if
(
"
${
cc_library_DEPS
}
"
MATCHES
"ARCHIVE_START"
)
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS).
# WARNING: Please don't use ARCHIVE_START&ARCHIVE_END if TARGET_NAME will be linked by other libraries.
target_circle_link_libraries
(
${
TARGET_NAME
}
${
cc_library_DEPS
}
)
list
(
REMOVE_ITEM cc_library_DEPS ARCHIVE_START ARCHIVE_END
)
else
()
target_link_libraries
(
${
TARGET_NAME
}
${
cc_library_DEPS
}
)
endif
()
add_dependencies
(
${
TARGET_NAME
}
${
cc_library_DEPS
}
)
endif
()
...
...
@@ -243,11 +236,7 @@ function(cc_test TARGET_NAME)
set
(
multiValueArgs SRCS DEPS ARGS
)
cmake_parse_arguments
(
cc_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
add_executable
(
${
TARGET_NAME
}
${
cc_test_SRCS
}
)
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main memory gtest gflags glog
)
if
(
"
${
cc_test_DEPS
}
"
MATCHES
"ARCHIVE_START"
)
list
(
REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END
)
endif
()
target_link_libraries
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main memory gtest gflags glog
)
add_dependencies
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main memory gtest gflags glog
)
add_test
(
NAME
${
TARGET_NAME
}
COMMAND
${
TARGET_NAME
}
${
cc_test_ARGS
}
...
...
cmake/inference_lib.cmake
浏览文件 @
ac78cc04
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set_property
(
GLOBAL PROPERTY FLUID_MODULES
""
)
# find all fluid modules is used for paddle fluid static library
function
(
find_fluid_modules TARGET_NAME
)
get_filename_component
(
__target_path
${
TARGET_NAME
}
ABSOLUTE
)
string
(
REGEX REPLACE
"^
${
PADDLE_SOURCE_DIR
}
/"
""
__target_path
${
__target_path
}
)
string
(
FIND
"
${
__target_path
}
"
"fluid"
pos
)
if
(
pos GREATER 1
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
...
...
@@ -77,6 +92,23 @@ elseif (WITH_MKLML)
)
endif
()
if
(
NOT MOBILE_INFERENCE AND NOT RPI
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/third_party/install/snappy"
)
copy
(
snappy_lib
SRCS
${
SNAPPY_INCLUDE_DIR
}
${
SNAPPY_LIBRARIES
}
DSTS
${
dst_dir
}
${
dst_dir
}
/lib
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/third_party/install/snappystream"
)
copy
(
snappystream_lib
SRCS
${
SNAPPYSTREAM_INCLUDE_DIR
}
${
SNAPPYSTREAM_LIBRARIES
}
DSTS
${
dst_dir
}
${
dst_dir
}
/lib
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/third_party/install/zlib"
)
copy
(
zlib_lib
SRCS
${
ZLIB_INCLUDE_DIR
}
${
ZLIB_LIBRARIES
}
DSTS
${
dst_dir
}
${
dst_dir
}
/lib
)
endif
()
# paddle fluid module
set
(
src_dir
"
${
PADDLE_SOURCE_DIR
}
/paddle/fluid"
)
set
(
dst_dir
"
${
CMAKE_INSTALL_PREFIX
}
/paddle/fluid"
)
...
...
paddle/CMakeLists.txt
浏览文件 @
ac78cc04
...
...
@@ -24,6 +24,6 @@ if(NOT WITH_FLUID_ONLY)
endif
()
add_subdirectory
(
testing
)
if
(
NOT MOBILE_INFERENCE AND NOT
ANDROID AND NOT IOS
)
if
(
NOT MOBILE_INFERENCE AND NOT
RPI
)
add_subdirectory
(
fluid
)
endif
()
paddle/fluid/CMakeLists.txt
浏览文件 @
ac78cc04
...
...
@@ -3,6 +3,7 @@ add_subdirectory(platform)
add_subdirectory
(
framework
)
add_subdirectory
(
operators
)
add_subdirectory
(
pybind
)
add_subdirectory
(
inference
)
add_subdirectory
(
string
)
add_subdirectory
(
recordio
)
# NOTE: please add subdirectory inference at last.
add_subdirectory
(
inference
)
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
ac78cc04
set
(
FLUID_CORE_MODULES proto_desc memory lod_tensor executor
prune
init
)
set
(
FLUID_CORE_MODULES proto_desc memory lod_tensor executor init
)
cc_library
(
paddle_fluid_api
SRCS io.cc
...
...
@@ -11,7 +11,7 @@ cc_library(paddle_fluid DEPS ${fluid_modules})
# Create shared library
cc_library
(
paddle_fluid_shared SHARED
SRCS io.cc
DEPS
ARCHIVE_START
${
GLOB_OP_LIB
}
${
FLUID_CORE_MODULES
}
ARCHIVE_END
)
DEPS
${
fluid_modules
}
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
if
(
NOT APPLE
)
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
...
...
paddle/fluid/inference/io.cc
浏览文件 @
ac78cc04
...
...
@@ -17,10 +17,16 @@ limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/pybind/pybind.h"
namespace
paddle
{
namespace
inference
{
// Temporarilly add this function for exposing framework::InitDevices() when
// linking the inference shared library.
void
Init
(
bool
init_p2p
)
{
framework
::
InitDevices
(
init_p2p
);
}
void
ReadBinaryFile
(
const
std
::
string
&
filename
,
std
::
string
&
contents
)
{
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
filename
);
...
...
paddle/fluid/inference/io.h
浏览文件 @
ac78cc04
...
...
@@ -18,12 +18,15 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
inference
{
void
Init
(
bool
init_p2p
);
void
LoadPersistables
(
framework
::
Executor
&
executor
,
framework
::
Scope
&
scope
,
const
framework
::
ProgramDesc
&
main_program
,
const
std
::
string
&
dirname
,
...
...
paddle/fluid/inference/tests/book/CMakeLists.txt
浏览文件 @
ac78cc04
...
...
@@ -17,7 +17,7 @@ function(inference_test TARGET_NAME)
string
(
REGEX REPLACE
"^_$"
""
arg
"
${
arg
}
"
)
cc_test
(
test_inference_
${
TARGET_NAME
}${
arg
}
SRCS test_inference_
${
TARGET_NAME
}
.cc
DEPS
ARCHIVE_START paddle_fluid ARCHIVE_END
DEPS
paddle_fluid
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
${
TARGET_NAME
}${
arg
}
.inference.model
)
set_tests_properties
(
test_inference_
${
TARGET_NAME
}${
arg
}
PROPERTIES DEPENDS test_
${
TARGET_NAME
}
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
ac78cc04
...
...
@@ -1183,6 +1183,8 @@ class Parameter(Variable):
self
.
gradient_clip_attr
=
kwargs
.
get
(
'gradient_clip_attr'
,
None
)
self
.
do_model_average
=
kwargs
.
get
(
'do_model_average'
,
None
)
def
__str__
(
self
):
return
self
.
to_string
(
True
)
...
...
@@ -1203,7 +1205,7 @@ class Parameter(Variable):
if
with_details
:
res_str
=
Variable
.
to_string
(
self
,
throw_on_error
,
True
)
additional_attr
=
(
"trainable"
,
"optimize_attr"
,
"regularizer"
,
"gradient_clip_attr"
)
"gradient_clip_attr"
,
"do_model_average"
)
for
attr_name
in
additional_attr
:
res_str
+=
"%s: %s
\n
"
%
(
attr_name
,
str
(
getattr
(
self
,
attr_name
)))
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
ac78cc04
...
...
@@ -1516,7 +1516,8 @@ def batch_norm(input,
in_place
=
False
,
name
=
None
,
moving_mean_name
=
None
,
moving_variance_name
=
None
):
moving_variance_name
=
None
,
do_model_average_for_mean_and_var
=
False
):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
...
...
@@ -1547,7 +1548,10 @@ def batch_norm(input,
mean
=
helper
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_mean_name
,
initializer
=
Constant
(
0.0
),
trainable
=
False
),
name
=
moving_mean_name
,
initializer
=
Constant
(
0.0
),
trainable
=
False
,
do_model_average
=
do_model_average_for_mean_and_var
),
shape
=
param_shape
,
dtype
=
input
.
dtype
)
mean
.
stop_gradient
=
True
...
...
@@ -1556,7 +1560,8 @@ def batch_norm(input,
attr
=
ParamAttr
(
name
=
moving_variance_name
,
initializer
=
Constant
(
1.0
),
trainable
=
False
),
trainable
=
False
,
do_model_average
=
do_model_average_for_mean_and_var
),
shape
=
param_shape
,
dtype
=
input
.
dtype
)
variance
.
stop_gradient
=
True
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
ac78cc04
...
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
from
collections
import
defaultdict
from
paddle.fluid.framework
import
Program
import
framework
...
...
@@ -818,8 +818,8 @@ class ModelAverage(Optimizer):
min_average_window, max_average_window and current update times.
Args:
params_grads: A list of parameter-grad variable pairs.
average_window_rate: The rate of average window.
params_grads: A list of parameter-grad variable pairs.
min_average_window: The minimum size of average window.
max_average_window: The maximum size of average window.
...
...
@@ -840,8 +840,8 @@ class ModelAverage(Optimizer):
"""
def
__init__
(
self
,
params_grads
,
average_window_rate
,
params_grads
=
None
,
min_average_window
=
10000
,
max_average_window
=
10000
,
**
kwargs
):
...
...
@@ -849,23 +849,36 @@ class ModelAverage(Optimizer):
self
.
average_window
=
average_window_rate
self
.
min_average_window
=
min_average_window
self
.
max_average_window
=
max_average_window
self
.
params_grads
=
params_grads
self
.
params_grads
=
[]
if
params_grads
is
None
else
params_grads
params
=
{}
for
param
,
grad
in
self
.
params_grads
:
if
param
.
do_model_average
!=
False
:
params
[
param
.
name
]
=
(
param
,
grad
)
for
param
in
framework
.
default_main_program
().
global_block
(
).
all_parameters
():
if
param
.
name
not
in
params
and
param
.
do_model_average
!=
False
:
grad
=
param
.
block
.
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
([
param
.
name
,
'tmp'
])),
dtype
=
param
.
dtype
,
persistable
=
False
,
stop_gradient
=
True
)
params
[
param
.
name
]
=
(
param
,
grad
)
self
.
params_grads
=
params
.
values
()
for
param
,
grad
in
self
.
params_grads
:
if
grad
is
not
None
:
self
.
_append_average_accumulate_op
(
param
)
self
.
apply_program
=
Program
()
block
=
self
.
apply_program
.
global_block
()
with
program_guard
(
main_program
=
self
.
apply_program
):
for
param_grad
in
self
.
params_grads
:
if
param_grad
[
1
]
is
not
None
:
self
.
_add_average_apply_op
(
block
,
param_grad
)
self
.
restore_program
=
Program
()
block
=
self
.
restore_program
.
global_block
()
with
program_guard
(
main_program
=
self
.
restore_program
):
for
param_grad
in
self
.
params_grads
:
if
param_grad
[
1
]
is
not
None
:
self
.
_add_average_restore_op
(
block
,
param_grad
)
def
_add_average_apply_op
(
self
,
block
,
param_grad
):
...
...
python/paddle/fluid/param_attr.py
浏览文件 @
ac78cc04
...
...
@@ -28,13 +28,15 @@ class ParamAttr(object):
learning_rate
=
1.0
,
regularizer
=
None
,
trainable
=
True
,
gradient_clip
=
None
):
gradient_clip
=
None
,
do_model_average
=
None
):
self
.
name
=
name
self
.
initializer
=
initializer
self
.
learning_rate
=
learning_rate
self
.
regularizer
=
regularizer
self
.
trainable
=
trainable
self
.
gradient_clip
=
gradient_clip
self
.
model_average
=
do_model_average
def
set_default_initializer
(
self
,
initializer
):
if
initializer
is
None
:
...
...
@@ -80,7 +82,8 @@ class ParamAttr(object):
},
'regularizer'
:
self
.
regularizer
,
'trainable'
:
self
.
trainable
,
'gradient_clip_attr'
:
self
.
gradient_clip
'gradient_clip_attr'
:
self
.
gradient_clip
,
'model_average'
:
self
.
model_average
}
if
with_initializer
:
kwargs
[
'initializer'
]
=
self
.
initializer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录