Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1795e576
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1795e576
编写于
8月 22, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into lookup_table
上级
0f3b9e41
ce723af0
变更
33
显示空白变更内容
内联
并排
Showing
33 changed file
with
393 addition
and
135 deletion
+393
-135
doc/api/v2/config/layer.rst
doc/api/v2/config/layer.rst
+5
-0
paddle/framework/backward.cc
paddle/framework/backward.cc
+5
-4
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+15
-15
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+2
-2
paddle/function/GemmFunctor.cpp
paddle/function/GemmFunctor.cpp
+2
-2
paddle/gserver/gradientmachines/NeuralNetwork.cpp
paddle/gserver/gradientmachines/NeuralNetwork.cpp
+1
-1
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
+1
-1
paddle/gserver/layers/ScaleShiftLayer.cpp
paddle/gserver/layers/ScaleShiftLayer.cpp
+107
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+15
-0
paddle/gserver/tests/test_NetworkCompare.cpp
paddle/gserver/tests/test_NetworkCompare.cpp
+2
-1
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+13
-16
paddle/operators/gaussian_random_op.cu
paddle/operators/gaussian_random_op.cu
+35
-23
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-1
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-3
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+5
-5
paddle/operators/rowwise_add_op.h
paddle/operators/rowwise_add_op.h
+10
-10
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+3
-6
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+1
-4
paddle/parameter/Parameter.h
paddle/parameter/Parameter.h
+4
-1
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+0
-16
paddle/platform/device_context.h
paddle/platform/device_context.h
+2
-11
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+0
-2
paddle/pserver/ParameterClient2.cpp
paddle/pserver/ParameterClient2.cpp
+14
-2
paddle/pserver/ParameterClient2.h
paddle/pserver/ParameterClient2.h
+1
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+14
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+42
-0
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
.../paddle/trainer_config_helpers/tests/configs/file_list.sh
+1
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr
...rs/tests/configs/protostr/test_scale_shift_layer.protostr
+72
-0
python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py
...er_config_helpers/tests/configs/test_scale_shift_layer.py
+9
-0
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-1
python/paddle/v2/framework/tests/test_net.py
python/paddle/v2/framework/tests/test_net.py
+5
-5
python/paddle/v2/framework/tests/test_recurrent_op.py
python/paddle/v2/framework/tests/test_recurrent_op.py
+1
-1
python/paddle/v2/framework/tests/test_rowwise_add_op.py
python/paddle/v2/framework/tests/test_rowwise_add_op.py
+1
-1
未找到文件。
doc/api/v2/config/layer.rst
浏览文件 @
1795e576
...
@@ -362,6 +362,11 @@ trans
...
@@ -362,6 +362,11 @@ trans
.. autoclass:: paddle.v2.layer.trans
.. autoclass:: paddle.v2.layer.trans
:noindex:
:noindex:
scale_shift
-----------
.. autoclass:: paddle.v2.layer.scale_shift
:noindex:
Sampling Layers
Sampling Layers
===============
===============
...
...
paddle/framework/backward.cc
浏览文件 @
1795e576
...
@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
...
@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
return
false
;
return
false
;
});
});
net
->
A
d
dOp
(
std
::
move
(
bwd
));
net
->
A
ppen
dOp
(
std
::
move
(
bwd
));
}
}
// Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
auto
uid
=
uniq_id
++
;
...
@@ -163,7 +163,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
...
@@ -163,7 +163,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{{
"Src"
,
{
prefix
}}},
net
->
AppendOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{{
"Src"
,
{
prefix
}}},
{{
"Dst"
,
{
grad_input
}}},
{}));
{{
"Dst"
,
{
grad_input
}}},
{}));
}
}
return
false
;
return
false
;
...
@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
...
@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
return
grad_op
;
}
}
net
->
A
d
dOp
(
std
::
move
(
grad_op
));
net
->
A
ppen
dOp
(
std
::
move
(
grad_op
));
}
}
net
->
SetType
(
"@GENERATED_BACKWARD@"
);
net
->
SetType
(
"@GENERATED_BACKWARD@"
);
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
...
...
paddle/framework/backward_test.cc
浏览文件 @
1795e576
...
@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
...
@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
A
d
dOp
(
OpRegistry
::
CreateOp
(
"mul"
,
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
auto
input_b
=
Inputs
(
"b"
);
auto
input_b
=
Inputs
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
std
::
string
before_act
=
"mul_result"
;
if
(
input_b
.
size
()
!=
0
)
{
if
(
input_b
.
size
()
!=
0
)
{
A
d
dOp
(
OpRegistry
::
CreateOp
(
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
input_b
[
0
]}}},
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
input_b
[
0
]}}},
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
before_act
=
"add_result"
;
before_act
=
"add_result"
;
...
@@ -92,7 +92,7 @@ class FcOp : public operators::NetOp {
...
@@ -92,7 +92,7 @@ class FcOp : public operators::NetOp {
}
}
}
}
A
d
dOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{{
"X"
,
{
Output
(
before_act
)}}},
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{{
"X"
,
{
Output
(
before_act
)}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
{}));
{{
"Out"
,
{
Output
(
"Out"
)}}},
{}));
CompleteAddOp
(
false
);
CompleteAddOp
(
false
);
}
}
...
@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
...
@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"W1"
}},
{
"b"
,
{
"b1"
}}},
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"W1"
}},
{
"b"
,
{
"b1"
}}},
{{
"mul_result"
,
{
"mul_tmp_0"
}},
{{
"mul_result"
,
{
"mul_tmp_0"
}},
{
"add_result"
,
{
"add_tmp_0"
}},
{
"add_result"
,
{
"add_tmp_0"
}},
{
"Out"
,
{
"hidden0"
}}},
{
"Out"
,
{
"hidden0"
}}},
{}));
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"hidden0"
}},
{
"W"
,
{
"W2"
}},
{
"b"
,
{
"b2"
}}},
"fc"
,
{{
"X"
,
{
"hidden0"
}},
{
"W"
,
{
"W2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_tmp_1"
}},
{{
"mul_result"
,
{
"mul_tmp_1"
}},
{
"add_result"
,
{
"add_tmp_1"
}},
{
"add_result"
,
{
"add_tmp_1"
}},
...
@@ -273,9 +273,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
...
@@ -273,9 +273,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
TEST
(
Backward
,
net_shared_weight
)
{
TEST
(
Backward
,
net_shared_weight
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"out"
}}},
{}));
{{
"Out"
,
{
"out"
}}},
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"FinalOut"
}}},
{}));
{{
"Out"
,
{
"FinalOut"
}}},
{}));
net
.
CompleteAddOp
();
net
.
CompleteAddOp
();
...
@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
...
@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x1"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
"fc"
,
{{
"X"
,
{
"x1"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"mul_result"
,
{
"mul_out1"
}},
{{
"mul_result"
,
{
"mul_out1"
}},
{
"add_result"
,
{
"add_out1"
}},
{
"add_result"
,
{
"add_out1"
}},
{
"Out"
,
{
"out1"
}}},
{
"Out"
,
{
"out1"
}}},
{}));
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out1"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
"fc"
,
{{
"X"
,
{
"out1"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_out2"
}},
{{
"mul_result"
,
{
"mul_out2"
}},
{
"add_result"
,
{
"tmp_out2"
}},
{
"add_result"
,
{
"tmp_out2"
}},
{
"Out"
,
{
"out2"
}}},
{
"Out"
,
{
"out2"
}}},
{}));
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out2"
}},
{
"W"
,
{
"w3"
}},
{
"b"
,
{
"b3"
}}},
"fc"
,
{{
"X"
,
{
"out2"
}},
{
"W"
,
{
"w3"
}},
{
"b"
,
{
"b3"
}}},
{{
"mul_result"
,
{
"mul_out3"
}},
{{
"mul_result"
,
{
"mul_out3"
}},
{
"add_result"
,
{
"tmp_out3"
}},
{
"add_result"
,
{
"tmp_out3"
}},
...
...
paddle/framework/pybind.cc
浏览文件 @
1795e576
...
@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle.
retv
->
SetType
(
"plain_net"
);
retv
->
SetType
(
"plain_net"
);
return
retv
;
return
retv
;
})
})
.
def
(
"a
d
d_op"
,
[](
operators
::
NetOp
&
self
,
.
def
(
"a
ppen
d_op"
,
[](
operators
::
NetOp
&
self
,
const
OperatorBase
&
op
)
{
self
.
Ad
dOp
(
op
);
})
const
OperatorBase
&
op
)
{
self
.
Appen
dOp
(
op
);
})
.
def
(
"complete_add_op"
,
&
operators
::
NetOp
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
&
operators
::
NetOp
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
[](
std
::
shared_ptr
<
operators
::
NetOp
>
&
self
)
{
.
def
(
"complete_add_op"
,
[](
std
::
shared_ptr
<
operators
::
NetOp
>
&
self
)
{
self
->
CompleteAddOp
();
self
->
CompleteAddOp
();
...
...
paddle/function/GemmFunctor.cpp
浏览文件 @
1795e576
...
@@ -84,7 +84,7 @@ struct BlasGemm<DEVICE_TYPE_GPU, T> {
...
@@ -84,7 +84,7 @@ struct BlasGemm<DEVICE_TYPE_GPU, T> {
}
}
};
};
template
class
BlasGemm
<
DEVICE_TYPE_CPU
,
real
>;
template
struct
BlasGemm
<
DEVICE_TYPE_CPU
,
real
>;
template
class
BlasGemm
<
DEVICE_TYPE_GPU
,
real
>;
template
struct
BlasGemm
<
DEVICE_TYPE_GPU
,
real
>;
}
// namespace paddle
}
// namespace paddle
paddle/gserver/gradientmachines/NeuralNetwork.cpp
浏览文件 @
1795e576
...
@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
...
@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto
mat
=
dynamic_cast
<
SparsePrefetchRowCpuMatrix
*>
(
auto
mat
=
dynamic_cast
<
SparsePrefetchRowCpuMatrix
*>
(
para
->
getMat
(
PARAMETER_VALUE
).
get
());
para
->
getMat
(
PARAMETER_VALUE
).
get
());
para
->
clearGradient
();
para
->
clearGradient
();
mat
->
clearIndices
();
if
(
mat
)
mat
->
clearIndices
();
}
}
}
}
}
}
...
...
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
浏览文件 @
1795e576
...
@@ -184,7 +184,7 @@ public:
...
@@ -184,7 +184,7 @@ public:
}
}
void
backward
(
const
UpdateCallback
&
callback
)
override
{
void
backward
(
const
UpdateCallback
&
callback
)
override
{
if
(
biases_
)
{
if
(
biases_
&&
biases_
->
getWGrad
()
)
{
backwardActivation
();
backwardActivation
();
biases_
->
getWGrad
()
->
collectBias
(
*
getOutputGrad
(),
1
);
biases_
->
getWGrad
()
->
collectBias
(
*
getOutputGrad
(),
1
);
biases_
->
getParameterPtr
()
->
incUpdate
(
callback
);
biases_
->
getParameterPtr
()
->
incUpdate
(
callback
);
...
...
paddle/gserver/layers/ScaleShiftLayer.cpp
0 → 100644
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace
paddle
{
/**
* A layer applies a linear transformation to each element in each row of
* the input matrix. For each element, the layer first re-scale it and then
* adds a bias to it.
*
* \f[
* y = wx + b
* \f]
*
* Here, w is the scale and b is the bias. Both w and b are trainable scalars.
*
*/
class
ScaleShiftLayer
:
public
Layer
{
protected:
std
::
unique_ptr
<
Weight
>
scale_
;
std
::
unique_ptr
<
Weight
>
offset_
;
public:
explicit
ScaleShiftLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
override
;
void
forward
(
PassType
passType
)
override
;
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
};
REGISTER_LAYER
(
scale_shift
,
ScaleShiftLayer
);
bool
ScaleShiftLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
inputLayers_
.
size
(),
1U
);
scale_
.
reset
(
new
Weight
(
1
,
1
,
parameters_
[
0
]));
if
(
biasParameter_
.
get
()
!=
NULL
)
{
offset_
=
std
::
unique_ptr
<
Weight
>
(
new
Weight
(
1
,
1
,
biasParameter_
));
}
return
true
;
}
void
ScaleShiftLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
MatrixPtr
inV
=
getInputValue
(
0
);
resetOutput
(
inV
->
getHeight
(),
inV
->
getWidth
());
MatrixPtr
outV
=
getOutputValue
();
real
scaleValue
=
scale_
->
getW
()
->
getElement
(
0
,
0
);
outV
->
mulScalar
(
*
inV
,
scaleValue
);
if
(
offset_
)
{
real
offsetValue
=
offset_
->
getW
()
->
getElement
(
0
,
0
);
outV
->
add
(
offsetValue
);
}
}
void
ScaleShiftLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
MatrixPtr
inV
=
getInputValue
(
0
);
MatrixPtr
inG
=
getInputGrad
(
0
);
MatrixPtr
outV
=
getOutputValue
();
MatrixPtr
outG
=
getOutputGrad
();
/* Calculate the parameter gradient for the current layer */
if
(
scale_
->
getWGrad
())
{
MatrixPtr
rowSumMtx
;
Matrix
::
resizeOrCreate
(
rowSumMtx
,
outG
->
getHeight
(),
1
,
false
,
useGpu_
);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij}
rowSumMtx
->
sumOfProducts
(
/* b= */
*
inV
,
/* c= */
*
outG
,
/* scaleSum= */
1
,
/* scaleDest= */
0.
);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji}
scale_
->
getWGrad
()
->
sumCols
(
/* b= */
*
rowSumMtx
,
/* scaleSum= */
1.
,
/* scaleDest= */
1.
);
scale_
->
getParameterPtr
()
->
incUpdate
(
callback
);
}
if
(
offset_
&&
offset_
->
getWGrad
())
{
MatrixPtr
rowSumMtx
;
Matrix
::
resizeOrCreate
(
rowSumMtx
,
outG
->
getHeight
(),
1
,
false
,
useGpu_
);
rowSumMtx
->
sumRows
(
*
outG
,
1.
,
0.
);
offset_
->
getWGrad
()
->
sumCols
(
*
rowSumMtx
,
1.
,
1.
);
offset_
->
getParameterPtr
()
->
incUpdate
(
callback
);
}
/* Calculate the input layers error */
if
(
inG
)
{
real
scaleValue
=
scale_
->
getW
()
->
getElement
(
0
,
0
);
inG
->
add
(
*
outG
,
scaleValue
);
}
}
}
// namespace paddle
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
1795e576
...
@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) {
...
@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) {
}
}
}
}
TEST
(
Layer
,
ScaleShiftLayer
)
{
const
size_t
batchSize
=
16
;
const
size_t
size
=
32
;
TestConfig
config
;
config
.
layerConfig
.
set_type
(
"scale_shift"
);
config
.
layerConfig
.
set_size
(
size
);
config
.
biasSize
=
1
;
config
.
inputDefs
.
push_back
(
{
INPUT_DATA
,
"input"
,
/* dim= */
size
,
/* paraSize= */
1
});
config
.
layerConfig
.
add_inputs
();
for
(
auto
useGpu
:
{
false
,
true
})
{
testLayerGrad
(
config
,
"scale_shift"
,
batchSize
,
false
,
useGpu
,
false
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
initMain
(
argc
,
argv
);
initMain
(
argc
,
argv
);
...
...
paddle/gserver/tests/test_NetworkCompare.cpp
浏览文件 @
1795e576
...
@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) {
...
@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) {
bool
useGpu
=
FLAGS_use_gpu
;
bool
useGpu
=
FLAGS_use_gpu
;
double
eps
=
FLAGS_checkgrad_eps
;
double
eps
=
FLAGS_checkgrad_eps
;
FLAGS_use_gpu
=
true
;
FLAGS_use_gpu
=
true
;
FLAGS_checkgrad_eps
=
1e-2
;
// Sometimes, this unit test will fail with 1e-2
FLAGS_checkgrad_eps
=
4e-2
;
compareNetwork
(
config_file_a
,
config_file_b
);
compareNetwork
(
config_file_a
,
config_file_b
);
FLAGS_use_gpu
=
useGpu
;
FLAGS_use_gpu
=
useGpu
;
FLAGS_checkgrad_eps
=
eps
;
FLAGS_checkgrad_eps
=
eps
;
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -19,25 +16,25 @@ namespace paddle {
...
@@ -19,25 +16,25 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
class
CPU
GaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// TODO(dzh): attribute does not support unsigned int.
unsigned
int
seed
=
// And we need a global random seed configuration.
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
;
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
seed
=
std
::
random_device
()();
}
}
std
::
mt19937
g
(
seed
);
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
ribution
(
mean
,
std
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
in
t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
ssize_
t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
ribution
(
g
);
data
[
i
]
=
dist
(
engine
);
}
}
}
}
};
};
...
@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
...
@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
PADDLE_ENFORCE
(
dims
.
size
()
>
0UL
,
PADDLE_ENFORCE
(
dims
.
size
()
>
0UL
,
"dims can be one int or array. dims must be set."
);
"dims can be one int or array. dims must be set."
);
...
@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
...
@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
float
>
(
"mean"
,
"mean
value of random
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"mean"
,
"mean
of random tensor
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"
minimum value of random value
."
).
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"std"
,
"
std of random tensor
."
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
AddAttr
<
int
>
(
"seed"
,
"Random seed of generator."
"Random seed of generator."
"0 means use system wide seed"
)
"0 means use system wide seed"
)
...
@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
...
@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
ops
::
GaussianRandomOpMaker
);
ops
::
GaussianRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPU
GaussianRandomKernel
<
float
>
);
\ No newline at end of file
paddle/operators/gaussian_random_op.cu
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <memory>
#include <thrust/device_ptr.h>
#include <random>
#include <thrust/iterator/counting_iterator.h>
#include "paddle/platform/dynload/curand.h"
#include <thrust/random.h>
#include "paddle/platform/gpu_info.h"
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
struct
GaussianGenerator
{
T
mean_
,
std_
;
unsigned
int
seed_
;
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
normal_distribution
<
T
>
dist
(
mean_
,
std_
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
template
<
typename
T
>
class
GPUGaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
);
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
);
if
(
seed
==
0
)
{
if
(
seed
==
0
)
{
std
::
random_device
rd
;
std
::
random_device
rd
;
seed
=
rd
();
seed
=
rd
();
}
}
curandGenerator_t
g
;
T
mean
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"mean"
))
;
PADDLE_ENFORCE
(
platform
::
dynload
::
curandCreateGenerator
(
T
std
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"std"
));
&
g
,
CURAND_RNG_PSEUDO_DEFAULT
)
);
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
PADDLE_ENFORCE
(
ssize_t
N
=
framework
::
product
(
tensor
->
dims
());
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
g
,
seed
));
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
N
,
platform
::
dynload
::
curandGenerateNormal
(
thrust
::
device_ptr
<
T
>
(
data
),
g
,
data
,
framework
::
product
(
tensor
->
dims
()),
mean
,
std
);
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
)
);
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
paddle
::
operators
::
GPU
GaussianRandomKernel
<
float
>
);
paddle/operators/mul_op.cc
浏览文件 @
1795e576
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
limitations under the License. */
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/operators/net_op.h
浏览文件 @
1795e576
...
@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase {
...
@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase {
return
true
;
return
true
;
}
}
void
A
ddOp
(
const
framework
::
OperatorBase
&
op
)
{
Ad
dOp
(
op
.
Clone
());
}
void
A
ppendOp
(
const
framework
::
OperatorBase
&
op
)
{
Appen
dOp
(
op
.
Clone
());
}
/**
/**
* @brief Add an operator by ptr
* @brief Add an operator by ptr
*/
*/
void
AddOp
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
op
)
{
void
AppendOp
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
op
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AppendOp when this network is sealed"
);
PADDLE_ENFORCE_NOT_NULL
(
op
,
"Cannot Insert Null op"
);
PADDLE_ENFORCE_NOT_NULL
(
op
,
"Cannot Insert Null op"
);
ops_
.
push_back
(
std
::
move
(
op
));
ops_
.
push_back
(
std
::
move
(
op
));
}
}
...
...
paddle/operators/net_op_test.cc
浏览文件 @
1795e576
...
@@ -38,10 +38,10 @@ TEST(OpKernel, all) {
...
@@ -38,10 +38,10 @@ TEST(OpKernel, all) {
auto
net
=
std
::
make_shared
<
NetOp
>
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
net
->
A
d
dOp
(
std
::
unique_ptr
<
TestOp
>
(
net
->
A
ppen
dOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
new
TestOp
(
"test"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{})));
{{
"Out"
,
{
"y"
}}},
{})));
net
->
A
d
dOp
(
std
::
unique_ptr
<
TestOp
>
(
net
->
A
ppen
dOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
new
TestOp
(
"test"
,
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"Out"
,
{
"z"
}}},
{})));
{{
"Out"
,
{
"z"
}}},
{})));
...
@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) {
...
@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) {
auto
op1
=
std
::
unique_ptr
<
framework
::
NOP
>
(
auto
op1
=
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
(
"empty"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
new
framework
::
NOP
(
"empty"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{}));
{{
"Out"
,
{
"y"
}}},
{}));
net
.
A
d
dOp
(
*
op1
);
net
.
A
ppen
dOp
(
*
op1
);
net
.
InsertOp
(
0
,
*
op1
);
net
.
InsertOp
(
0
,
*
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
net
.
InsertOp
(
2
,
std
::
move
(
op1
));
net
.
InsertOp
(
2
,
std
::
move
(
op1
));
...
@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) {
...
@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) {
TEST
(
NetOp
,
Clone
)
{
TEST
(
NetOp
,
Clone
)
{
NetOp
net
;
NetOp
net
;
net
.
A
d
dOp
(
net
.
A
ppen
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
net
.
A
d
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
net
.
A
ppen
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty2"
,
{},
{},
{}}));
new
framework
::
NOP
{
"empty2"
,
{},
{},
{}}));
net
.
CompleteAddOp
(
true
);
net
.
CompleteAddOp
(
true
);
auto
new_net_op
=
net
.
Clone
();
auto
new_net_op
=
net
.
Clone
();
...
...
paddle/operators/rowwise_add_op.h
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/eigen.h"
...
@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel {
...
@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel {
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
// colwise add
Eigen
::
array
<
int
,
1
>
dims
{{
1
}};
/* dimension to reduce */
Eigen
::
array
<
int
,
1
>
dims
{{
0
}};
/* dimension to reduce */
EigenVector
<
T
>::
Flatten
(
*
db
).
device
(
place
)
=
OutGrad
.
sum
(
dims
);
EigenVector
<
T
>::
Flatten
(
*
db
).
device
(
place
)
=
OutGrad
.
sum
(
dims
);
}
}
};
};
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
...
@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
std
::
uniform_real_distribution
<
T
>
dist
(
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
for
(
ssize_t
i
=
0
;
i
<
framework
::
product
(
tensor
->
dims
());
++
i
)
{
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
ssize_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
data
[
i
]
=
dist
(
engine
);
}
}
}
}
...
@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddComment
(
R"DOC(Uniform random operator.
AddComment
(
R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
Used to initialize tensor with uniform random generator.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
...
...
paddle/operators/uniform_random_op.cu
浏览文件 @
1795e576
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
paddle/parameter/Parameter.h
浏览文件 @
1795e576
...
@@ -65,8 +65,11 @@ public:
...
@@ -65,8 +65,11 @@ public:
size_t
getSize
()
const
{
return
config_
.
size
();
}
size_t
getSize
()
const
{
return
config_
.
size
();
}
bool
isFullSize
()
const
{
bool
isFullSize
()
const
{
if
(
bufs_
[
PARAMETER_VALUE
])
{
return
this
->
getSize
()
==
bufs_
[
PARAMETER_VALUE
]
->
getSize
();
return
this
->
getSize
()
==
bufs_
[
PARAMETER_VALUE
]
->
getSize
();
}
}
return
false
;
}
inline
bool
useGpu
()
const
{
return
useGpu_
;
}
inline
bool
useGpu
()
const
{
return
useGpu_
;
}
...
...
paddle/platform/device_context.cc
浏览文件 @
1795e576
...
@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
}
if
(
curand_generator_
)
{
PADDLE_ENFORCE
(
dynload
::
curandDestroyGenerator
(
curand_generator_
));
}
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
...
@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_ONLY_CPU
}
// namespace platform
}
// namespace platform
...
...
paddle/platform/device_context.h
浏览文件 @
1795e576
...
@@ -17,7 +17,6 @@ limitations under the License. */
...
@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#define EIGEN_USE_GPU
#endif
#endif
...
@@ -40,7 +39,7 @@ class DeviceContext {
...
@@ -40,7 +39,7 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
class
CPUDeviceContext
:
public
DeviceContext
{
public:
public:
CPUDeviceContext
();
CPUDeviceContext
();
explicit
CPUDeviceContext
(
CPUPlace
);
explicit
CPUDeviceContext
(
CPUPlace
place
);
virtual
~
CPUDeviceContext
()
{}
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
...
@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
...
@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
GPUPlace
);
explicit
CUDADeviceContext
(
GPUPlace
place
);
virtual
~
CUDADeviceContext
();
virtual
~
CUDADeviceContext
();
/*! \brief Wait for all operations completion in the stream. */
/*! \brief Wait for all operations completion in the stream. */
...
@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
();
cudnnHandle_t
cudnn_handle
();
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
/*! \brief Return cuda stream in the device context. */
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
cudaStream_t
stream
();
// clang-format on
// clang-format on
...
@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
private:
private:
GPUPlace
place_
;
GPUPlace
place_
;
private:
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
private:
uint64_t
seed_
;
// clang-format off
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
// clang-format on
};
};
...
...
paddle/platform/device_context_test.cc
浏览文件 @
1795e576
...
@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
...
@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
delete
device_context
;
}
}
...
...
paddle/pserver/ParameterClient2.cpp
浏览文件 @
1795e576
...
@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
...
@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
LOG
(
INFO
)
<<
"parallel_thread_num dosent need to set"
;
LOG
(
INFO
)
<<
"parallel_thread_num dosent need to set"
;
}
}
syncThreadPool_
.
reset
(
new
SyncThreadPool
(
threadNum_
));
syncThreadPool_
.
reset
(
new
SyncThreadPool
(
threadNum_
));
startThreads
();
startThreads
();
}
}
...
@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
...
@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
request
.
set_cost
(
cost
);
request
.
set_cost
(
cost
);
request
.
set_batch_status
(
batchStatus
);
request
.
set_batch_status
(
batchStatus
);
CHECK_EQ
(
request
.
blocks_size
(),
0
);
CHECK_EQ
(
request
.
blocks_size
(),
0
);
VLOG
(
10
)
<<
"request: trainer_id: "
<<
request
.
trainer_id
()
<<
" update_mode"
<<
request
.
update_mode
()
<<
" send_back_parameter: "
<<
request
.
send_back_parameter
()
<<
" send_back_parameter_type: "
<<
request
.
send_back_parameter_type
()
<<
" num_samples: "
<<
request
.
num_samples
()
<<
" cost: "
<<
request
.
cost
()
<<
" batch_status: "
<<
request
.
batch_status
();
}
}
for
(
const
auto
&
segments
:
parameterSegments
)
{
for
(
const
auto
&
segments
:
parameterSegments
)
{
const
auto
it
=
parameterMap_
.
find
(
segments
.
id
);
const
auto
it
=
parameterMap_
.
find
(
segments
.
id
);
...
@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
...
@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
CHECK
(
sendMat
!=
nullptr
)
<<
"sendMat is nullptr"
;
CHECK
(
sendMat
!=
nullptr
)
<<
"sendMat is nullptr"
;
syncThreadPool_
->
exec
([
&
](
int
tid
,
size_t
numThreads
)
{
syncThreadPool_
->
exec
([
&
](
int
tid
,
size_t
numThreads
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
sparseAutoGrowthMutex_
);
const
auto
&
localIndices
=
prefetchMat
->
getLocalIndices
();
const
auto
&
localIndices
=
prefetchMat
->
getLocalIndices
();
/// num of sparse rows
/// num of sparse rows
size_t
nLocalBlocks
=
localIndices
.
size
();
size_t
nLocalBlocks
=
localIndices
.
size
();
uint64_t
beginDim
=
0
;
uint64_t
beginDim
=
0
;
uint64_t
endDim
=
0
;
uint64_t
endDim
=
0
;
// FIXME(typhoonzero): let it resize first
prefetchMat
->
getLocalRow
(
nLocalBlocks
+
1
);
sendMat
->
getLocalRow
(
nLocalBlocks
+
1
);
for
(
size_t
row
=
0
;
row
<
nLocalBlocks
;
++
row
)
{
for
(
size_t
row
=
0
;
row
<
nLocalBlocks
;
++
row
)
{
int64_t
blockId
=
localIndices
[
row
];
// local row -> sparse row
int64_t
blockId
=
localIndices
[
row
];
// local row -> sparse row
int
serverId
=
std
::
abs
((
blockId
+
nameHash
)
%
serviceNum_
);
int
serverId
=
std
::
abs
((
blockId
+
nameHash
)
%
serviceNum_
);
...
@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
...
@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
block
->
set_begin_pos
(
row
*
blockSize
);
block
->
set_begin_pos
(
row
*
blockSize
);
/// block len
/// block len
block
->
set_block_size
(
endDim
-
beginDim
);
block
->
set_block_size
(
endDim
-
beginDim
);
if
(
sendingPara
)
{
if
(
sendingPara
)
{
sendJob
->
parallelInputIovs
[
serverId
].
push_back
(
sendJob
->
parallelInputIovs
[
serverId
].
push_back
(
{
sendMat
->
getLocalRow
(
row
),
sizeof
(
real
)
*
(
size_t
)
blockSize
});
{
sendMat
->
getLocalRow
(
row
),
sizeof
(
real
)
*
(
size_t
)
blockSize
});
...
...
paddle/pserver/ParameterClient2.h
浏览文件 @
1795e576
...
@@ -583,6 +583,7 @@ protected:
...
@@ -583,6 +583,7 @@ protected:
#ifndef PADDLE_DISABLE_TIMER
#ifndef PADDLE_DISABLE_TIMER
uint64_t
forwardbackwordTime_
;
uint64_t
forwardbackwordTime_
;
#endif
#endif
std
::
mutex
sparseAutoGrowthMutex_
;
/// map id to parameter used for decoding protobuf data
/// map id to parameter used for decoding protobuf data
std
::
unordered_map
<
size_t
,
ParameterPtr
>
parameterMap_
;
std
::
unordered_map
<
size_t
,
ParameterPtr
>
parameterMap_
;
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
1795e576
...
@@ -2232,6 +2232,20 @@ class ClipLayer(LayerBase):
...
@@ -2232,6 +2232,20 @@ class ClipLayer(LayerBase):
self
.
config
.
inputs
[
0
].
clip_conf
.
max
=
max
self
.
config
.
inputs
[
0
].
clip_conf
.
max
=
max
@
config_layer
(
'scale_shift'
)
class
ScaleShiftLayer
(
LayerBase
):
def
__init__
(
self
,
name
,
inputs
,
bias
=
True
,
**
xargs
):
super
(
ScaleShiftLayer
,
self
).
__init__
(
name
,
'scale_shift'
,
0
,
inputs
=
inputs
,
**
xargs
)
config_assert
(
len
(
self
.
inputs
)
==
1
,
'ScaleShiftLayer must have one and only one input.'
)
input_layer
=
self
.
get_input_layer
(
0
)
self
.
set_layer_size
(
input_layer
.
size
)
self
.
create_input_parameter
(
0
,
1
,
[
1
,
1
])
self
.
create_bias_parameter
(
bias
,
1
)
# key: cost type
# key: cost type
# value: cost class
# value: cost class
g_cost_map
=
{}
g_cost_map
=
{}
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
1795e576
...
@@ -133,6 +133,7 @@ __all__ = [
...
@@ -133,6 +133,7 @@ __all__ = [
'clip_layer'
,
'clip_layer'
,
'slice_projection'
,
'slice_projection'
,
'kmax_sequence_score_layer'
,
'kmax_sequence_score_layer'
,
'scale_shift_layer'
,
]
]
...
@@ -230,6 +231,7 @@ class LayerType(object):
...
@@ -230,6 +231,7 @@ class LayerType(object):
CLIP_LAYER
=
'clip'
CLIP_LAYER
=
'clip'
KMAX_SEQ_SCORE
=
'kmax_seq_score'
KMAX_SEQ_SCORE
=
'kmax_seq_score'
SCALE_SHIFT_LAYER
=
'scale_shift'
@
staticmethod
@
staticmethod
def
is_layer_type
(
type_name
):
def
is_layer_type
(
type_name
):
...
@@ -6210,3 +6212,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1):
...
@@ -6210,3 +6212,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1):
return
LayerOutput
(
return
LayerOutput
(
name
,
LayerType
.
KMAX_SEQ_SCORE
,
parents
=
[
input
],
size
=
input
.
size
)
name
,
LayerType
.
KMAX_SEQ_SCORE
,
parents
=
[
input
],
size
=
input
.
size
)
@
wrap_name_default
(
"scale_shift"
)
@
wrap_param_attr_default
()
@
wrap_bias_attr_default
()
def
scale_shift_layer
(
input
,
name
=
None
,
param_attr
=
None
,
bias_attr
=
None
):
"""
A layer applies a linear transformation to each element in each row of
the input matrix. For each element, the layer first re-scale it and then
adds a bias to it.
This layer is very like the SlopeInterceptLayer, except the scale and
bias are trainable.
.. math::
y = w * x + b
.. code-block:: python
scale_shift = scale_shift_layer(input=input_layer, bias_attr=False)
:param name: The Layer Name.
:type name: basestring
:param input: The input layer.
:type input: LayerOutput.
:param param_attr: The parameter attribute of scaling.
:type param_attr: ParameterAttribute
:param bias_attr: The parameter attribute of shifting.
:type bias_attr: ParameterAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer
(
name
=
name
,
type
=
LayerType
.
SCALE_SHIFT_LAYER
,
inputs
=
Input
(
input
.
name
,
**
param_attr
.
attr
),
bias
=
ParamAttr
.
to_bias
(
bias_attr
))
return
LayerOutput
(
name
,
LayerType
.
SCALE_SHIFT_LAYER
,
parents
=
[
input
],
size
=
input
.
size
)
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
浏览文件 @
1795e576
...
@@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
...
@@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers
)
test_kmax_seq_socre_layer test_seq_select_layers
test_scale_shift_layer
)
export
whole_configs
=(
test_split_datasource
)
export
whole_configs
=(
test_split_datasource
)
python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr
0 → 100644
浏览文件 @
1795e576
type: "nn"
layers {
name: "data"
type: "data"
size: 100
active_type: ""
}
layers {
name: "__scale_shift_0__"
type: "scale_shift"
size: 100
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___scale_shift_0__.w0"
}
}
layers {
name: "__scale_shift_1__"
type: "scale_shift"
size: 100
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___scale_shift_1__.w0"
}
bias_parameter_name: "___scale_shift_1__.wbias"
}
parameters {
name: "___scale_shift_0__.w0"
size: 1
initial_mean: 0.0
initial_std: 1.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___scale_shift_1__.w0"
size: 1
initial_mean: 0.0
initial_std: 1.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___scale_shift_1__.wbias"
size: 1
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "__scale_shift_0__"
output_layer_names: "__scale_shift_1__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__scale_shift_0__"
layer_names: "__scale_shift_1__"
input_layer_names: "data"
output_layer_names: "__scale_shift_0__"
output_layer_names: "__scale_shift_1__"
is_recurrent_layer_group: false
}
python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py
0 → 100644
浏览文件 @
1795e576
from
paddle.trainer_config_helpers
import
*
data
=
data_layer
(
name
=
'data'
,
size
=
100
)
scale
=
scale_shift_layer
(
input
=
data
,
bias_attr
=
False
)
scale_shift
=
scale_shift_layer
(
input
=
data
)
outputs
(
scale
,
scale_shift
)
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
1795e576
...
@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
...
@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_operator SRCS test_operator.py
)
py_test
(
test_operator SRCS test_operator.py
)
#
py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test
(
test_gaussian_random_op SRCS test_gaussian_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
...
...
python/paddle/v2/framework/tests/test_net.py
浏览文件 @
1795e576
...
@@ -6,8 +6,8 @@ import unittest
...
@@ -6,8 +6,8 @@ import unittest
def
fc
(
X
,
W
,
Y
):
def
fc
(
X
,
W
,
Y
):
ret_v
=
core
.
Net
.
create
()
ret_v
=
core
.
Net
.
create
()
ret_v
.
a
d
d_op
(
Operator
(
"mul"
,
X
=
"X"
,
Y
=
"W"
,
Out
=
"pre_activation"
))
ret_v
.
a
ppen
d_op
(
Operator
(
"mul"
,
X
=
"X"
,
Y
=
"W"
,
Out
=
"pre_activation"
))
ret_v
.
a
d
d_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Y
=
Y
))
ret_v
.
a
ppen
d_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Y
=
Y
))
ret_v
.
complete_add_op
(
True
)
ret_v
.
complete_add_op
(
True
)
return
ret_v
return
ret_v
...
@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase):
...
@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase):
def
test_net_all
(
self
):
def
test_net_all
(
self
):
net
=
core
.
Net
.
create
()
net
=
core
.
Net
.
create
()
op1
=
Operator
(
"add_two"
,
X
=
"X"
,
Y
=
"Y"
,
Out
=
"Out"
)
op1
=
Operator
(
"add_two"
,
X
=
"X"
,
Y
=
"Y"
,
Out
=
"Out"
)
net
.
a
d
d_op
(
op1
)
net
.
a
ppen
d_op
(
op1
)
net2
=
core
.
Net
.
create
()
net2
=
core
.
Net
.
create
()
net2
.
a
d
d_op
(
fc
(
X
=
"X"
,
W
=
"w"
,
Y
=
"fc.out"
))
net2
.
a
ppen
d_op
(
fc
(
X
=
"X"
,
W
=
"w"
,
Y
=
"fc.out"
))
net2
.
complete_add_op
(
True
)
net2
.
complete_add_op
(
True
)
net
.
a
d
d_op
(
net2
)
net
.
a
ppen
d_op
(
net2
)
net
.
complete_add_op
(
True
)
net
.
complete_add_op
(
True
)
expected
=
'''
expected
=
'''
...
...
python/paddle/v2/framework/tests/test_recurrent_op.py
浏览文件 @
1795e576
...
@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase):
...
@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase):
sig_op
=
Operator
(
"sigmoid"
,
X
=
"sum"
,
Y
=
"h@alias"
)
sig_op
=
Operator
(
"sigmoid"
,
X
=
"sum"
,
Y
=
"h@alias"
)
for
op
in
[
x_fc_op
,
h_fc_op
,
sum_op
,
sig_op
]:
for
op
in
[
x_fc_op
,
h_fc_op
,
sum_op
,
sig_op
]:
stepnet
.
a
d
d_op
(
op
)
stepnet
.
a
ppen
d_op
(
op
)
stepnet
.
complete_add_op
(
True
)
stepnet
.
complete_add_op
(
True
)
self
.
rnnop
.
set_stepnet
(
stepnet
)
self
.
rnnop
.
set_stepnet
(
stepnet
)
...
...
python/paddle/v2/framework/tests/test_rowwise_add_op.py
浏览文件 @
1795e576
...
@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker):
...
@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker):
def
test_rowwise_add
(
self
):
def
test_rowwise_add
(
self
):
op
=
create_op
(
"rowwise_add"
)
op
=
create_op
(
"rowwise_add"
)
inputs
=
{
inputs
=
{
"X"
:
np
.
random
.
uniform
(
0.1
,
1
,
[
10
,
10
]).
astype
(
"float32"
),
"X"
:
np
.
random
.
uniform
(
0.1
,
1
,
[
5
,
10
]).
astype
(
"float32"
),
"b"
:
np
.
random
.
uniform
(
0.1
,
1
,
[
10
]).
astype
(
"float32"
)
"b"
:
np
.
random
.
uniform
(
0.1
,
1
,
[
10
]).
astype
(
"float32"
)
}
}
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
,
"b"
]),
"Out"
)
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
,
"b"
]),
"Out"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录