Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
e92f0021
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e92f0021
编写于
7月 10, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into mobilenet_gpu
上级
a3ce6aa8
1038bc46
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
274 addition
and
50 deletion
+274
-50
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/generic.cmake
cmake/generic.cmake
+10
-10
paddle/optimizer/adadelta_optimizer.cc
paddle/optimizer/adadelta_optimizer.cc
+5
-3
paddle/optimizer/adagrad_optimizer.cc
paddle/optimizer/adagrad_optimizer.cc
+6
-3
paddle/optimizer/adam_optimizer.cc
paddle/optimizer/adam_optimizer.cc
+6
-3
paddle/optimizer/lr_policy.h
paddle/optimizer/lr_policy.h
+34
-14
paddle/optimizer/sgd_optimizer.cc
paddle/optimizer/sgd_optimizer.cc
+5
-1
paddle/platform/CMakeLists.txt
paddle/platform/CMakeLists.txt
+2
-0
paddle/platform/device_context.h
paddle/platform/device_context.h
+159
-0
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+33
-0
proto/OptimizerConfig.proto
proto/OptimizerConfig.proto
+12
-15
python/setup.py.in
python/setup.py.in
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
e92f0021
...
...
@@ -16,6 +16,7 @@ cmake_minimum_required(VERSION 3.0)
set
(
CMAKE_MODULE_PATH
${
CMAKE_MODULE_PATH
}
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake"
)
set
(
PROJ_ROOT
${
CMAKE_CURRENT_SOURCE_DIR
}
)
set
(
PROJ_BINARY_ROOT
${
CMAKE_CURRENT_BINARY_DIR
}
)
include
(
system
)
...
...
cmake/generic.cmake
浏览文件 @
e92f0021
...
...
@@ -88,7 +88,7 @@
#
# including binary directory for generated headers.
include_directories
(
${
CMAKE_BINARY_DIR
}
)
include_directories
(
${
CMAKE_
CURRENT_
BINARY_DIR
}
)
if
(
NOT APPLE
)
find_package
(
Threads REQUIRED
)
...
...
@@ -106,7 +106,7 @@ function(merge_static_libs TARGET_NAME)
if
(
APPLE
)
# Use OSX's libtool to merge archives
# To produce a library we need at least one source file.
# It is created by add_custom_command below and will helps
# It is created by add_custom_command below and will helps
# also help to track dependencies.
set
(
dummyfile
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
_dummy.c
)
...
...
@@ -144,24 +144,24 @@ function(merge_static_libs TARGET_NAME)
DEPENDS
${
lib
}
${
objdir
}
WORKING_DIRECTORY
${
objdir
}
)
# Empty dummy source file that goes into merged library
set
(
mergebase
${
lib
}
.mergebase.c
)
add_custom_command
(
OUTPUT
${
mergebase
}
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
mergebase
}
DEPENDS
${
objlistfile
}
)
# Empty dummy source file that goes into merged library
set
(
mergebase
${
lib
}
.mergebase.c
)
add_custom_command
(
OUTPUT
${
mergebase
}
COMMAND
${
CMAKE_COMMAND
}
-E touch
${
mergebase
}
DEPENDS
${
objlistfile
}
)
list
(
APPEND mergebases
"
${
mergebase
}
"
)
endforeach
()
add_library
(
${
TARGET_NAME
}
STATIC
${
mergebases
}
)
target_link_libraries
(
${
TARGET_NAME
}
${
libs_deps
}
)
target_link_libraries
(
${
TARGET_NAME
}
${
libs_deps
}
)
# Get the file name of the generated library
set
(
outlibfile
"$<TARGET_FILE:
${
TARGET_NAME
}
>"
)
foreach
(
lib
${
libs
}
)
add_custom_command
(
TARGET
${
TARGET_NAME
}
POST_BUILD
COMMAND
${
CMAKE_AR
}
cr
${
outlibfile
}
*.o
COMMAND
${
CMAKE_AR
}
cr
${
outlibfile
}
*.o
COMMAND
${
CMAKE_RANLIB
}
${
outlibfile
}
WORKING_DIRECTORY
${
lib
}
.objdir
)
endforeach
()
...
...
@@ -362,4 +362,4 @@ function(py_proto_compile TARGET_NAME)
set
(
py_srcs
)
protobuf_generate_python
(
py_srcs
${
py_proto_compile_SRCS
}
)
add_custom_target
(
${
TARGET_NAME
}
ALL DEPENDS
${
py_srcs
}
)
endfunction
()
\ No newline at end of file
endfunction
()
paddle/optimizer/adadelta_optimizer.cc
浏览文件 @
e92f0021
...
...
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
const
char
*
AdadeltaOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdadeltaOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
TensorToProto
(
*
accum_delta_
,
state
.
mutable_accum_delta
());
TensorToProto
(
*
update_delta_
,
state
.
mutable_update_delta
());
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
}
void
AdadeltaOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdadeltaOptimizerState
state
;
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
...
...
paddle/optimizer/adagrad_optimizer.cc
浏览文件 @
e92f0021
...
...
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
}
const
char
*
AdagradOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdagradOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
accum_gradient_
,
state
.
mutable_accum_gradient
());
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
}
void
AdagradOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdagradOptimizerState
state
;
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
ProtoToTensor
(
state
.
accum_gradient
(),
accum_gradient_
);
...
...
paddle/optimizer/adam_optimizer.cc
浏览文件 @
e92f0021
...
...
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
const
char
*
AdamOptimizer
::
SerializeState
(
int
*
state_len
)
{
AdamOptimizerState
state
;
// TODO(zhihong) : add lr_policy serialization
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
state
.
set_num_sample_passed
(
num_sample_passed_
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
TensorToProto
(
*
velocitys_
,
state
.
mutable_velocitys
());
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
}
void
AdamOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
AdamOptimizerState
state
;
state
.
ParseFromString
(
str
);
// TODO(zhihong) : add lr_policy DeserializeState
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
...
...
paddle/optimizer/lr_policy.h
浏览文件 @
e92f0021
...
...
@@ -17,36 +17,56 @@ public:
// constant learning rate policy
class
ConstLr
final
:
public
LrPolicy
{
public:
ConstLr
(
double
lr
)
:
learning_rate
(
lr
){};
ConstLr
(
double
lr
)
:
learning_rate
_
(
lr
){};
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
return
learning_rate
;
return
learning_rate_
;
}
const
char
*
SerializeState
(
int
*
state_len
)
{
LrPolicyState
state
;
state
.
set_learning_rate
(
learning_rate_
);
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
return
str
.
c_str
();
}
void
DeserializeState
(
const
std
::
string
&
str
)
{
LrPolicyState
state
;
state
.
ParseFromString
(
str
);
learning_rate_
=
state
.
learning_rate
();
}
const
char
*
SerializeState
(
int
*
state_len
)
{
return
nullptr
;
}
void
DeserializeState
(
const
std
::
string
&
state
)
{}
private:
double
learning_rate
;
double
learning_rate
_
;
};
class
LinearLr
final
:
public
LrPolicy
{
public:
LinearLr
(
double
lr
,
double
lr_decay_a
,
double
lr_decay_b
)
:
learning_rate
(
lr
),
lr_decay_a
(
lr_decay_a
),
lr_decay_b
(
lr_decay_b
)
{}
:
learning_rate
_
(
lr
),
lr_decay_a_
(
lr_decay_a
),
lr_decay_b_
(
lr_decay_b
)
{}
double
LearningRate
(
const
uint64_t
num_sample_passed
)
{
return
std
::
max
(
learning_rate
-
lr_decay_a
*
num_sample_passed
,
lr_decay_b
);
return
std
::
max
(
learning_rate_
-
lr_decay_a_
*
num_sample_passed
,
lr_decay_b_
);
}
const
char
*
SerializeState
(
int
*
state_len
)
{
// TODO(zhihong) : add lr_policy serialization
return
nullptr
;
LrPolicyState
state
;
state
.
set_learning_rate
(
learning_rate_
);
state
.
set_lr_decay_a
(
lr_decay_a_
);
state
.
set_lr_decay_b
(
lr_decay_b_
);
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
return
str
.
c_str
();
}
void
DeserializeState
(
const
std
::
string
&
state
)
{
// TODO(zhihong) : add lr_policy serialization
void
DeserializeState
(
const
std
::
string
&
str
)
{
LrPolicyState
state
;
state
.
ParseFromString
(
str
);
learning_rate_
=
state
.
learning_rate
();
lr_decay_a_
=
state
.
lr_decay_a
();
lr_decay_b_
=
state
.
lr_decay_b
();
}
private:
double
learning_rate
;
double
lr_decay_a
;
double
lr_decay_b
;
double
learning_rate
_
;
double
lr_decay_a
_
;
double
lr_decay_b
_
;
};
}
// namespace optimizer
...
...
paddle/optimizer/sgd_optimizer.cc
浏览文件 @
e92f0021
...
...
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
const
char
*
SGDOptimizer
::
SerializeState
(
int
*
state_len
)
{
SGDOptimizerState
state
;
state
.
set_num_sample_passed
(
num_sample_passed_
);
std
::
string
lr_str
=
this
->
lr_policy_
->
SerializeState
(
state_len
);
state
.
mutable_lr_state
()
->
ParseFromString
(
lr_str
);
TensorToProto
(
*
parameter_
,
state
.
mutable_parameter
());
if
(
momentum_
!=
0.0
)
TensorToProto
(
*
momentums_
,
state
.
mutable_momentums
());
auto
str
=
state
.
SerializeAsString
();
*
state_len
=
str
.
size
();
*
state_len
+
=
str
.
size
();
return
str
.
c_str
();
}
void
SGDOptimizer
::
DeserializeState
(
const
std
::
string
&
str
)
{
SGDOptimizerState
state
;
state
.
ParseFromString
(
str
);
auto
lr_state
=
state
.
lr_state
();
this
->
lr_policy_
->
DeserializeState
(
lr_state
.
SerializeAsString
());
num_sample_passed_
=
state
.
num_sample_passed
();
ProtoToTensor
(
state
.
parameter
(),
parameter_
);
if
(
momentum_
!=
0.0
)
ProtoToTensor
(
state
.
parameter
(),
momentums_
);
...
...
paddle/platform/CMakeLists.txt
浏览文件 @
e92f0021
...
...
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library
(
place SRCS place.cc
)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
nv_test
(
device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags
)
paddle/platform/device_context.h
0 → 100644
浏览文件 @
e92f0021
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#endif
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
platform
{
class
DeviceContext
{
public:
virtual
~
DeviceContext
()
{}
};
class
CPUDeviceContext
:
public
DeviceContext
{};
#ifndef PADDLE_ONLY_CPU
class
GPUPlaceGuard
{
public:
explicit
GPUPlaceGuard
(
GPUPlace
new_place
)
:
previous_
(
GetCurrentDeviceId
())
{
if
(
previous_
!=
new_place
)
{
paddle
::
platform
::
SetDeviceId
(
new_place
.
device
);
}
}
~
GPUPlaceGuard
()
{
paddle
::
platform
::
SetDeviceId
(
previous_
.
device
);
}
private:
GPUPlace
previous_
;
};
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
const
GPUPlace
gpu_place
)
:
gpu_place_
(
gpu_place
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
paddle
::
platform
::
throw_on_error
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
eigen_stream_
=
new
Eigen
::
CudaStreamDevice
(
&
stream_
);
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
}
void
Wait
()
{
paddle
::
platform
::
throw_on_error
(
cudaStreamSynchronize
(
stream_
),
"cudaStreamSynchronize failed"
);
}
cudaStream_t
stream
()
{
return
stream_
;
}
Eigen
::
GpuDevice
eigen_device
()
{
return
*
eigen_device_
;
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasCreate
(
&
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasSetStream
(
blas_handle_
,
stream_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasSetStream failed"
);
}
return
blas_handle_
;
}
cudnnHandle_t
cudnn_handle
()
{
if
(
!
dnn_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnCreate
(
&
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnSetStream
(
dnn_handle_
,
stream_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnSetStream failed"
);
}
return
dnn_handle_
;
}
curandGenerator_t
curand_generator
()
{
if
(
!
rand_generator_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandCreateGenerator
(
&
rand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
)
==
CURAND_STATUS_SUCCESS
,
"curandCreateGenerator failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
rand_generator_
,
random_seed_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetPseudoRandomGeneratorSeed failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetStream
(
rand_generator_
,
stream_
)
==
CURAND_STATUS_SUCCESS
,
"curandSetStream failed"
);
}
return
rand_generator_
;
}
~
CUDADeviceContext
()
{
Wait
();
if
(
blas_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasDestroy
(
blas_handle_
)
==
CUBLAS_STATUS_SUCCESS
,
"cublasDestroy failed"
);
}
if
(
dnn_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnDestroy
(
dnn_handle_
)
==
CUDNN_STATUS_SUCCESS
,
"cudnnDestroy failed"
);
}
if
(
rand_generator_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandDestroyGenerator
(
rand_generator_
)
==
CURAND_STATUS_SUCCESS
,
"curandDestroyGenerator failed"
);
}
delete
eigen_stream_
;
delete
eigen_device_
;
paddle
::
platform
::
throw_on_error
(
cudaStreamDestroy
(
stream_
),
"cudaStreamDestroy failed"
);
}
private:
GPUPlace
gpu_place_
;
cudaStream_t
stream_
;
Eigen
::
CudaStreamDevice
*
eigen_stream_
;
Eigen
::
GpuDevice
*
eigen_device_
;
cublasHandle_t
blas_handle_
{
nullptr
};
cudnnHandle_t
dnn_handle_
{
nullptr
};
int
random_seed_
;
curandGenerator_t
rand_generator_
{
nullptr
};
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/platform/device_context_test.cc
0 → 100644
浏览文件 @
e92f0021
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "gtest/gtest.h"
TEST
(
CUDADeviceContext
,
Init
)
{
int
count
=
paddle
::
platform
::
GetDeviceCount
();
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
paddle
::
platform
::
CUDADeviceContext
*
device_context
=
new
paddle
::
platform
::
CUDADeviceContext
(
i
);
Eigen
::
GpuDevice
gpu_device
=
device_context
->
eigen_device
();
ASSERT_NE
(
nullptr
,
gpu_device
.
stream
());
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
delete
device_context
;
}
}
proto/OptimizerConfig.proto
浏览文件 @
e92f0021
...
...
@@ -78,11 +78,15 @@ enum DataType {
repeated
bytes
content
=
2
;
}
message
LrPolicyState
{
// learninRate Policy
optional
double
learning_rate
=
1
[
default
=
1.0
];
optional
double
lr_decay_a
=
2
;
optional
double
lr_decay_b
=
3
;
}
message
SGDOptimizerState
{
// learning rate policy
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
LrPolicyState
lr_state
=
101
;
optional
double
num_sample_passed
=
104
;
// state
optional
TensorProto
parameter
=
1
;
...
...
@@ -91,9 +95,7 @@ message SGDOptimizerState {
message
AdadeltaOptimizerState
{
// learning rate policy
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
LrPolicyState
lr_state
=
101
;
optional
double
num_sample_passed
=
104
;
// state
optional
TensorProto
parameter
=
1
;
...
...
@@ -102,11 +104,9 @@ message AdadeltaOptimizerState {
optional
TensorProto
update_delta
=
4
;
}
message
AdagradOptimizerState
{
// learning rate policy
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
LrPolicyState
lr_state
=
101
;
optional
double
num_sample_passed
=
104
;
// state
optional
TensorProto
parameter
=
1
;
...
...
@@ -114,10 +114,7 @@ message AdagradOptimizerState {
}
message
AdamOptimizerState
{
// learning rate policy
optional
double
learning_rate
=
101
;
optional
double
lr_decay_a
=
102
;
optional
double
lr_decay_b
=
103
;
optional
LrPolicyState
lr_state
=
101
;
optional
double
num_sample_passed
=
104
;
// state
optional
TensorProto
parameter
=
1
;
...
...
python/setup.py.in
浏览文件 @
e92f0021
...
...
@@ -34,6 +34,6 @@ setup(name='paddle',
'': '${CMAKE_CURRENT_SOURCE_DIR}',
# The paddle.v2.framework.proto will be generated while compiling.
# So that package points to other directory.
'paddle.v2.framework.proto': '${
CMAKE_BINARY_DIR
}/paddle/framework'
'paddle.v2.framework.proto': '${
PROJ_BINARY_ROOT
}/paddle/framework'
},
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录