Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0a0f1948
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a0f1948
编写于
8月 22, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into develop
上级
dc5f0dbc
5810d63f
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
250 addition
and
140 deletion
+250
-140
paddle/framework/backward.cc
paddle/framework/backward.cc
+5
-4
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+15
-15
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+3
-3
paddle/gserver/gradientmachines/NeuralNetwork.cpp
paddle/gserver/gradientmachines/NeuralNetwork.cpp
+1
-1
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+6
-9
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+117
-5
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+11
-3
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+13
-16
paddle/operators/gaussian_random_op.cu
paddle/operators/gaussian_random_op.cu
+35
-23
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+0
-1
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-3
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+5
-5
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+2
-5
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+0
-3
paddle/parameter/Parameter.h
paddle/parameter/Parameter.h
+4
-1
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+0
-16
paddle/platform/device_context.h
paddle/platform/device_context.h
+2
-11
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+0
-2
paddle/pserver/ParameterClient2.cpp
paddle/pserver/ParameterClient2.cpp
+14
-2
paddle/pserver/ParameterClient2.h
paddle/pserver/ParameterClient2.h
+1
-0
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-1
python/paddle/v2/framework/tests/op_test_util.py
python/paddle/v2/framework/tests/op_test_util.py
+2
-1
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+3
-4
python/paddle/v2/framework/tests/test_net.py
python/paddle/v2/framework/tests/test_net.py
+5
-5
python/paddle/v2/framework/tests/test_recurrent_op.py
python/paddle/v2/framework/tests/test_recurrent_op.py
+1
-1
未找到文件。
paddle/framework/backward.cc
浏览文件 @
0a0f1948
...
...
@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
return
false
;
});
net
->
A
d
dOp
(
std
::
move
(
bwd
));
net
->
A
ppen
dOp
(
std
::
move
(
bwd
));
}
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
...
...
@@ -163,8 +163,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{{
"Src"
,
{
prefix
}}},
{{
"Dst"
,
{
grad_input
}}},
{}));
net
->
AppendOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{{
"Src"
,
{
prefix
}}},
{{
"Dst"
,
{
grad_input
}}},
{}));
}
return
false
;
});
...
...
@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
}
net
->
A
d
dOp
(
std
::
move
(
grad_op
));
net
->
A
ppen
dOp
(
std
::
move
(
grad_op
));
}
net
->
SetType
(
"@GENERATED_BACKWARD@"
);
net
->
CompleteAddOp
();
...
...
paddle/framework/backward_test.cc
浏览文件 @
0a0f1948
...
...
@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
A
d
dOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
auto
input_b
=
Inputs
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
if
(
input_b
.
size
()
!=
0
)
{
A
d
dOp
(
OpRegistry
::
CreateOp
(
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
input_b
[
0
]}}},
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
before_act
=
"add_result"
;
...
...
@@ -92,8 +92,8 @@ class FcOp : public operators::NetOp {
}
}
A
d
dOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{{
"X"
,
{
Output
(
before_act
)}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
{}));
A
ppen
dOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{{
"X"
,
{
Output
(
before_act
)}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
{}));
CompleteAddOp
(
false
);
}
};
...
...
@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"W1"
}},
{
"b"
,
{
"b1"
}}},
{{
"mul_result"
,
{
"mul_tmp_0"
}},
{
"add_result"
,
{
"add_tmp_0"
}},
{
"Out"
,
{
"hidden0"
}}},
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"hidden0"
}},
{
"W"
,
{
"W2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_tmp_1"
}},
{
"add_result"
,
{
"add_tmp_1"
}},
...
...
@@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) {
TEST
(
Backward
,
net_shared_weight
)
{
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"out"
}}},
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"FinalOut"
}}},
{}));
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"out"
}}},
{}));
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"FinalOut"
}}},
{}));
net
.
CompleteAddOp
();
auto
bwd
=
f
::
Backward
(
net
,
{});
...
...
@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
ops
::
NetOp
net
;
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x1"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"mul_result"
,
{
"mul_out1"
}},
{
"add_result"
,
{
"add_out1"
}},
{
"Out"
,
{
"out1"
}}},
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out1"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_out2"
}},
{
"add_result"
,
{
"tmp_out2"
}},
{
"Out"
,
{
"out2"
}}},
{}));
net
.
A
d
dOp
(
f
::
OpRegistry
::
CreateOp
(
net
.
A
ppen
dOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out2"
}},
{
"W"
,
{
"w3"
}},
{
"b"
,
{
"b3"
}}},
{{
"mul_result"
,
{
"mul_out3"
}},
{
"add_result"
,
{
"tmp_out3"
}},
...
...
paddle/framework/pybind.cc
浏览文件 @
0a0f1948
...
...
@@ -31,7 +31,7 @@ limitations under the License. */
namespace
py
=
pybind11
;
USE_OP
(
add_two
);
USE_
CPU_ONLY_
OP
(
onehot_cross_entropy
);
USE_OP
(
onehot_cross_entropy
);
USE_OP
(
sgd
);
USE_OP
(
mul
);
USE_OP
(
mean
);
...
...
@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle.
retv
->
SetType
(
"plain_net"
);
return
retv
;
})
.
def
(
"a
d
d_op"
,
[](
operators
::
NetOp
&
self
,
const
OperatorBase
&
op
)
{
self
.
Ad
dOp
(
op
);
})
.
def
(
"a
ppen
d_op"
,
[](
operators
::
NetOp
&
self
,
const
OperatorBase
&
op
)
{
self
.
Appen
dOp
(
op
);
})
.
def
(
"complete_add_op"
,
&
operators
::
NetOp
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
[](
std
::
shared_ptr
<
operators
::
NetOp
>
&
self
)
{
self
->
CompleteAddOp
();
...
...
paddle/gserver/gradientmachines/NeuralNetwork.cpp
浏览文件 @
0a0f1948
...
...
@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto
mat
=
dynamic_cast
<
SparsePrefetchRowCpuMatrix
*>
(
para
->
getMat
(
PARAMETER_VALUE
).
get
());
para
->
clearGradient
();
mat
->
clearIndices
();
if
(
mat
)
mat
->
clearIndices
();
}
}
}
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
0a0f1948
...
...
@@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
X_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// TODO(superjom) add enforce here after helper functions ready
X_grad
->
Resize
(
X
->
dims
());
dX
->
Resize
(
X
->
dims
());
}
};
...
...
@@ -70,9 +69,7 @@ namespace ops = paddle::operators;
REGISTER_OP
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOp
,
ops
::
OnehotCrossEntropyOpMaker
,
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpKernel
<
float
>
);
paddle/operators/cross_entropy_op.cu
浏览文件 @
0a0f1948
...
...
@@ -12,10 +12,122 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__host__
__device__
T
clipping_log
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
T
v
=
log
(
x
);
if
(
v
==
INFINITY
)
{
return
kApproInf
;
}
if
(
v
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
v
;
}
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int
*
label
,
const
int
N
,
const
int
D
)
{
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
clipping_log
(
X
[
i
*
D
+
label
[
i
]]);
}
}
// TODO(qingqing): make zero setting an common function.
template
<
typename
T
>
__global__
void
zero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
X
[
i
]
=
0.0
;
}
}
template
<
typename
T
>
__global__
void
CrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
int
*
label
,
const
int
N
,
const
int
D
)
{
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
idx
=
i
*
D
+
label
[
i
];
dX
[
idx
]
=
-
dY
[
i
]
/
X
[
idx
];
}
}
template
<
typename
T
>
class
OnehotCrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
Ydata
=
Y
->
data
<
T
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
+
block
-
1
)
/
block
;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
Ydata
,
Xdata
,
label_data
,
N
,
D
);
}
};
template
<
typename
T
>
class
OnehotCrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dY
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"label"
);
auto
*
dXdata
=
dX
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
auto
*
dYdata
=
dY
->
template
data
<
T
>();
auto
*
Xdata
=
X
->
template
data
<
T
>();
auto
*
label_data
=
label
->
data
<
int
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
*
D
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dXdata
,
N
*
D
);
grid
=
(
N
+
block
-
1
)
/
block
;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dXdata
,
dYdata
,
Xdata
,
label_data
,
N
,
D
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpCUDAKernel
<
float
>
);
paddle/operators/cross_entropy_op.h
浏览文件 @
0a0f1948
...
...
@@ -21,7 +21,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
T
tolerable_value
(
T
x
)
{
inline
T
tolerable_value
(
const
T
x
)
{
static_assert
(
std
::
is_floating_point
<
T
>::
value
,
"tolerable_value works only on float, "
"double and double double."
);
...
...
@@ -39,10 +39,13 @@ T tolerable_value(T x) {
return
x
;
}
template
<
typename
Place
,
typename
T
>
template
<
typename
T
>
class
OnehotCrossEntropyOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
...
...
@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
}
};
template
<
typename
Place
,
typename
T
>
template
<
typename
T
>
class
OnehotCrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dY
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
...
...
@@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
const
int
batch_size
=
X
->
dims
()[
0
];
const
int
class_num
=
X
->
dims
()[
1
];
// TODO(qingqing): make zero setting an common function.
memset
(
dXdata
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
dXdata
[
index
]
=
-
tolerable_value
(
dYdata
[
i
]
/
Xdata
[
index
]);
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
0a0f1948
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -19,25 +16,25 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
class
CPU
GaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
;
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
std
::
mt19937
g
(
seed
);
std
::
normal_distribution
<
T
>
dist
ribution
(
mean
,
std
);
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
in
t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
ribution
(
g
);
for
(
ssize_
t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
};
...
...
@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
PADDLE_ENFORCE
(
dims
.
size
()
>
0UL
,
"dims can be one int or array. dims must be set."
);
...
...
@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"The dimension of random tensor."
);
AddAttr
<
float
>
(
"mean"
,
"mean
value of random
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"
minimum value of random value
."
).
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"mean"
,
"mean
of random tensor
."
).
SetDefault
(
.0
f
);
AddAttr
<
float
>
(
"std"
,
"
std of random tensor
."
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
"Random seed of generator."
"0 means use system wide seed"
)
...
...
@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
gaussian_random
,
ops
::
GaussianRandomOp
,
ops
::
GaussianRandomOpMaker
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPU
GaussianRandomKernel
<
float
>
);
paddle/operators/gaussian_random_op.cu
浏览文件 @
0a0f1948
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GaussianRandomKernel
:
public
framework
::
OpKernel
{
struct
GaussianGenerator
{
T
mean_
,
std_
;
unsigned
int
seed_
;
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
normal_distribution
<
T
>
dist
(
mean_
,
std_
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
template
<
typename
T
>
class
GPUGaussianRandomKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
op_
.
GetAttr
<
float
>
(
"mean"
);
float
std
=
context
.
op_
.
GetAttr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
seed
=
context
.
op_
.
GetAttr
<
int
>
(
"seed"
);
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
)
);
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
}
curandGenerator_t
g
;
PADDLE_ENFORCE
(
platform
::
dynload
::
curandCreateGenerator
(
&
g
,
CURAND_RNG_PSEUDO_DEFAULT
)
);
PADDLE_ENFORCE
(
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
g
,
seed
));
platform
::
dynload
::
curandGenerateNormal
(
g
,
data
,
framework
::
product
(
tensor
->
dims
()),
mean
,
std
);
T
mean
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"mean"
))
;
T
std
=
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
ssize_t
N
=
framework
::
product
(
tensor
->
dims
());
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
N
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
)
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
ops
::
GaussianRandomKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
gaussian_random
,
paddle
::
operators
::
GPU
GaussianRandomKernel
<
float
>
);
paddle/operators/mul_op.cc
浏览文件 @
0a0f1948
...
...
@@ -13,7 +13,6 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/operators/net_op.h
浏览文件 @
0a0f1948
...
...
@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase {
return
true
;
}
void
A
ddOp
(
const
framework
::
OperatorBase
&
op
)
{
Ad
dOp
(
op
.
Clone
());
}
void
A
ppendOp
(
const
framework
::
OperatorBase
&
op
)
{
Appen
dOp
(
op
.
Clone
());
}
/**
* @brief Add an operator by ptr
*/
void
AddOp
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
op
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
void
AppendOp
(
std
::
unique_ptr
<
framework
::
OperatorBase
>
op
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AppendOp when this network is sealed"
);
PADDLE_ENFORCE_NOT_NULL
(
op
,
"Cannot Insert Null op"
);
ops_
.
push_back
(
std
::
move
(
op
));
}
...
...
paddle/operators/net_op_test.cc
浏览文件 @
0a0f1948
...
...
@@ -38,10 +38,10 @@ TEST(OpKernel, all) {
auto
net
=
std
::
make_shared
<
NetOp
>
();
ASSERT_NE
(
net
,
nullptr
);
net
->
A
d
dOp
(
std
::
unique_ptr
<
TestOp
>
(
net
->
A
ppen
dOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{})));
net
->
A
d
dOp
(
std
::
unique_ptr
<
TestOp
>
(
net
->
A
ppen
dOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"Out"
,
{
"z"
}}},
{})));
...
...
@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) {
auto
op1
=
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
(
"empty"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{}));
net
.
A
d
dOp
(
*
op1
);
net
.
A
ppen
dOp
(
*
op1
);
net
.
InsertOp
(
0
,
*
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
net
.
InsertOp
(
2
,
std
::
move
(
op1
));
...
...
@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) {
TEST
(
NetOp
,
Clone
)
{
NetOp
net
;
net
.
A
d
dOp
(
net
.
A
ppen
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
net
.
A
d
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
net
.
A
ppen
dOp
(
std
::
unique_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty2"
,
{},
{},
{}}));
net
.
CompleteAddOp
(
true
);
auto
new_net_op
=
net
.
Clone
();
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
0a0f1948
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
context
.
op_
.
GetAttr
<
float
>
(
"max"
)));
for
(
ssize_t
i
=
0
;
i
<
framework
::
product
(
tensor
->
dims
());
++
i
)
{
ssize_t
size
=
framework
::
product
(
tensor
->
dims
());
for
(
ssize_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
...
...
@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddComment
(
R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"dims"
,
"the dimension of random tensor"
);
...
...
paddle/operators/uniform_random_op.cu
浏览文件 @
0a0f1948
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
paddle/parameter/Parameter.h
浏览文件 @
0a0f1948
...
...
@@ -65,7 +65,10 @@ public:
size_t
getSize
()
const
{
return
config_
.
size
();
}
bool
isFullSize
()
const
{
return
this
->
getSize
()
==
bufs_
[
PARAMETER_VALUE
]
->
getSize
();
if
(
bufs_
[
PARAMETER_VALUE
])
{
return
this
->
getSize
()
==
bufs_
[
PARAMETER_VALUE
]
->
getSize
();
}
return
false
;
}
inline
bool
useGpu
()
const
{
return
useGpu_
;
}
...
...
paddle/platform/device_context.cc
浏览文件 @
0a0f1948
...
...
@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
if
(
curand_generator_
)
{
PADDLE_ENFORCE
(
dynload
::
curandDestroyGenerator
(
curand_generator_
));
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
...
@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
#endif // PADDLE_ONLY_CPU
}
// namespace platform
...
...
paddle/platform/device_context.h
浏览文件 @
0a0f1948
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
...
...
@@ -40,7 +39,7 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
public:
CPUDeviceContext
();
explicit
CPUDeviceContext
(
CPUPlace
);
explicit
CPUDeviceContext
(
CPUPlace
place
);
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
...
...
@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
GPUPlace
);
explicit
CUDADeviceContext
(
GPUPlace
place
);
virtual
~
CUDADeviceContext
();
/*! \brief Wait for all operations completion in the stream. */
...
...
@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
();
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
// clang-format on
...
...
@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
private:
GPUPlace
place_
;
private:
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
private:
uint64_t
seed_
;
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
};
...
...
paddle/platform/device_context_test.cc
浏览文件 @
0a0f1948
...
...
@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
}
...
...
paddle/pserver/ParameterClient2.cpp
浏览文件 @
0a0f1948
...
...
@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
LOG
(
INFO
)
<<
"parallel_thread_num dosent need to set"
;
}
syncThreadPool_
.
reset
(
new
SyncThreadPool
(
threadNum_
));
startThreads
();
}
...
...
@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
request
.
set_cost
(
cost
);
request
.
set_batch_status
(
batchStatus
);
CHECK_EQ
(
request
.
blocks_size
(),
0
);
VLOG
(
10
)
<<
"request: trainer_id: "
<<
request
.
trainer_id
()
<<
" update_mode"
<<
request
.
update_mode
()
<<
" send_back_parameter: "
<<
request
.
send_back_parameter
()
<<
" send_back_parameter_type: "
<<
request
.
send_back_parameter_type
()
<<
" num_samples: "
<<
request
.
num_samples
()
<<
" cost: "
<<
request
.
cost
()
<<
" batch_status: "
<<
request
.
batch_status
();
}
for
(
const
auto
&
segments
:
parameterSegments
)
{
const
auto
it
=
parameterMap_
.
find
(
segments
.
id
);
...
...
@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
CHECK
(
sendMat
!=
nullptr
)
<<
"sendMat is nullptr"
;
syncThreadPool_
->
exec
([
&
](
int
tid
,
size_t
numThreads
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
sparseAutoGrowthMutex_
);
const
auto
&
localIndices
=
prefetchMat
->
getLocalIndices
();
/// num of sparse rows
size_t
nLocalBlocks
=
localIndices
.
size
();
uint64_t
beginDim
=
0
;
uint64_t
endDim
=
0
;
// FIXME(typhoonzero): let it resize first
prefetchMat
->
getLocalRow
(
nLocalBlocks
+
1
);
sendMat
->
getLocalRow
(
nLocalBlocks
+
1
);
for
(
size_t
row
=
0
;
row
<
nLocalBlocks
;
++
row
)
{
int64_t
blockId
=
localIndices
[
row
];
// local row -> sparse row
int
serverId
=
std
::
abs
((
blockId
+
nameHash
)
%
serviceNum_
);
...
...
@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
block
->
set_begin_pos
(
row
*
blockSize
);
/// block len
block
->
set_block_size
(
endDim
-
beginDim
);
if
(
sendingPara
)
{
sendJob
->
parallelInputIovs
[
serverId
].
push_back
(
{
sendMat
->
getLocalRow
(
row
),
sizeof
(
real
)
*
(
size_t
)
blockSize
});
...
...
paddle/pserver/ParameterClient2.h
浏览文件 @
0a0f1948
...
...
@@ -583,6 +583,7 @@ protected:
#ifndef PADDLE_DISABLE_TIMER
uint64_t
forwardbackwordTime_
;
#endif
std
::
mutex
sparseAutoGrowthMutex_
;
/// map id to parameter used for decoding protobuf data
std
::
unordered_map
<
size_t
,
ParameterPtr
>
parameterMap_
;
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
0a0f1948
...
...
@@ -23,7 +23,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_operator SRCS test_operator.py
)
#
py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test
(
test_gaussian_random_op SRCS test_gaussian_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
...
...
python/paddle/v2/framework/tests/op_test_util.py
浏览文件 @
0a0f1948
...
...
@@ -64,7 +64,8 @@ class OpTestMeta(type):
actual
=
numpy
.
array
(
scope
.
find_var
(
out_name
).
get_tensor
())
expect
=
self
.
outputs
[
out_name
]
self
.
assertTrue
(
numpy
.
allclose
(
actual
,
expect
),
numpy
.
allclose
(
actual
,
expect
,
atol
=
1e-05
),
"output name: "
+
out_name
+
"has diff"
)
obj
.
test_all
=
test_all
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
0a0f1948
...
...
@@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
# TODO this unit test is not passed
self
.
type
=
"onehot_cross_entropy"
batch_size
=
10
0
batch_size
=
3
0
class_num
=
10
X
=
numpy
.
random
.
random
((
batch_size
,
class_num
)).
astype
(
"float32"
)
label
=
5
*
numpy
.
ones
(
batch_size
).
astype
(
"int32"
)
...
...
@@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase):
class
CrossEntropyGradOpTest
(
GradientChecker
):
def
test_
softmax
_grad
(
self
):
def
test_
check
_grad
(
self
):
op
=
create_op
(
"onehot_cross_entropy"
)
batch_size
=
10
0
batch_size
=
3
0
class_num
=
10
inputs
=
{
"X"
:
numpy
.
random
.
uniform
(
...
...
python/paddle/v2/framework/tests/test_net.py
浏览文件 @
0a0f1948
...
...
@@ -6,8 +6,8 @@ import unittest
def
fc
(
X
,
W
,
Y
):
ret_v
=
core
.
Net
.
create
()
ret_v
.
a
d
d_op
(
Operator
(
"mul"
,
X
=
"X"
,
Y
=
"W"
,
Out
=
"pre_activation"
))
ret_v
.
a
d
d_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Y
=
Y
))
ret_v
.
a
ppen
d_op
(
Operator
(
"mul"
,
X
=
"X"
,
Y
=
"W"
,
Out
=
"pre_activation"
))
ret_v
.
a
ppen
d_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Y
=
Y
))
ret_v
.
complete_add_op
(
True
)
return
ret_v
...
...
@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase):
def
test_net_all
(
self
):
net
=
core
.
Net
.
create
()
op1
=
Operator
(
"add_two"
,
X
=
"X"
,
Y
=
"Y"
,
Out
=
"Out"
)
net
.
a
d
d_op
(
op1
)
net
.
a
ppen
d_op
(
op1
)
net2
=
core
.
Net
.
create
()
net2
.
a
d
d_op
(
fc
(
X
=
"X"
,
W
=
"w"
,
Y
=
"fc.out"
))
net2
.
a
ppen
d_op
(
fc
(
X
=
"X"
,
W
=
"w"
,
Y
=
"fc.out"
))
net2
.
complete_add_op
(
True
)
net
.
a
d
d_op
(
net2
)
net
.
a
ppen
d_op
(
net2
)
net
.
complete_add_op
(
True
)
expected
=
'''
...
...
python/paddle/v2/framework/tests/test_recurrent_op.py
浏览文件 @
0a0f1948
...
...
@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase):
sig_op
=
Operator
(
"sigmoid"
,
X
=
"sum"
,
Y
=
"h@alias"
)
for
op
in
[
x_fc_op
,
h_fc_op
,
sum_op
,
sig_op
]:
stepnet
.
a
d
d_op
(
op
)
stepnet
.
a
ppen
d_op
(
op
)
stepnet
.
complete_add_op
(
True
)
self
.
rnnop
.
set_stepnet
(
stepnet
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录