Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9445502f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9445502f
编写于
10月 10, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into trt_dy_lib
test=develop
上级
d347ea68
e1904ac2
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
195 addition
and
28 deletion
+195
-28
.gitignore
.gitignore
+1
-0
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+21
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+30
-0
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+3
-0
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+18
-2
paddle/fluid/operators/argsort_op.cc
paddle/fluid/operators/argsort_op.cc
+2
-2
paddle/fluid/operators/conv_shift_op.cc
paddle/fluid/operators/conv_shift_op.cc
+1
-1
paddle/fluid/operators/elementwise_op.h
paddle/fluid/operators/elementwise_op.h
+12
-7
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+2
-1
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+1
-0
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+1
-1
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
+1
-1
paddle/fluid/operators/sequence_conv_op.cc
paddle/fluid/operators/sequence_conv_op.cc
+2
-2
paddle/fluid/operators/sequence_pool_op.cc
paddle/fluid/operators/sequence_pool_op.cc
+3
-2
paddle/fluid/operators/sequence_reshape_op.cc
paddle/fluid/operators/sequence_reshape_op.cc
+1
-1
paddle/fluid/operators/sequence_softmax_op.cc
paddle/fluid/operators/sequence_softmax_op.cc
+2
-1
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+3
-3
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
...e/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+36
-2
python/paddle/fluid/lod_tensor.py
python/paddle/fluid/lod_tensor.py
+1
-1
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+53
-0
未找到文件。
.gitignore
浏览文件 @
9445502f
...
...
@@ -25,5 +25,6 @@ third_party/
bazel-*
third_party/
build_*
# clion workspace.
cmake-build-*
paddle/fluid/framework/op_desc.cc
浏览文件 @
9445502f
...
...
@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
override
;
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
PADDLE_ENFORCE
(
input_n
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
in
,
i
);
PADDLE_ENFORCE
(
output_n
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
out
,
j
);
auto
*
in_var
=
block_
.
FindVarRecursive
(
input_n
);
auto
*
out_var
=
block_
.
FindVarRecursive
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
GetType
()
==
out_var
->
GetType
(),
"The type of %s and %s is not the same."
,
input_n
,
output_n
);
SetDim
(
output_n
,
GetDim
(
input_n
));
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
9445502f
...
...
@@ -542,6 +542,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
op_
.
Outputs
(
name
);
}
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
Variable
*
in_var
=
scope_
.
FindVar
(
input_n
);
Variable
*
out_var
=
scope_
.
FindVar
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
Type
()
==
out_var
->
Type
(),
"The type of %s and %s is not the same."
,
output_n
,
GetDim
(
input_n
));
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
out_sele_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_sele_rows
->
mutable_value
()
->
Resize
(
in_sele_rows
.
value
().
dims
());
out_sele_rows
->
set_rows
(
in_sele_rows
.
rows
());
out_sele_rows
->
set_height
(
in_sele_rows
.
height
());
}
else
if
(
in_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
in_lod_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_lod_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_lod_tensor
->
Resize
(
in_lod_tensor
.
dims
());
}
else
{
PADDLE_THROW
(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."
);
}
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
const
std
::
vector
<
std
::
string
>&
inputs
=
Inputs
(
in
);
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
9445502f
...
...
@@ -56,6 +56,9 @@ class InferShapeContext {
virtual
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
=
0
;
virtual
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
=
0
;
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
9445502f
...
...
@@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
S
etOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
)
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel {
}
};
class
ActivationOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
x_name
=
op_desc
.
Input
(
"X"
)[
0
];
auto
out_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
&
x
=
block
->
FindRecursiveOrCreateVar
(
x_name
);
auto
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
out
.
SetType
(
x
.
GetType
());
out
.
SetDataType
(
x
.
GetDataType
());
}
};
class
ActivationOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
ctx
->
ShareDim
(
"Out"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"Out"
,
framework
::
GradVarName
(
"X"
));
}
protected:
...
...
@@ -525,12 +539,14 @@ namespace ops = paddle::operators;
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
...
...
paddle/fluid/operators/argsort_op.cc
浏览文件 @
9445502f
...
...
@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
"-rank(Input(X)) (%d)."
,
axis
,
num_dims
);
ctx
->
S
etOutputDim
(
"Out"
,
in_dims
);
ctx
->
S
etOutputDim
(
"Indices"
,
in_dims
);
ctx
->
S
hareDim
(
"X"
,
"Out"
);
ctx
->
S
hareDim
(
"X"
,
"Indices"
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
ctx
->
ShareLoD
(
"X"
,
"Indices"
);
}
...
...
paddle/fluid/operators/conv_shift_op.cc
浏览文件 @
9445502f
...
...
@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE
(
y_dims
[
1
],
x_dims
[
1
],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X)."
);
ctx
->
S
etOutputDim
(
"Out"
,
x_dims
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/elementwise_op.h
浏览文件 @
9445502f
...
...
@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto
y_dim
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_GE
(
x_dim
.
size
(),
y_dim
.
size
(),
"Rank of first input must >= rank of second input."
);
ctx
->
SetOutputDim
(
"Out"
,
x_dim
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
auto
&
x
=
block
->
FindRecursiveOrCreateVar
(
x_name
);
auto
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
out
.
SetType
(
x
.
GetType
());
out
.
SetDataType
(
x
.
GetDataType
());
}
};
...
...
@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
x_dims
);
ctx
->
ShareDim
(
"X"
,
/*->*/
x_grad_name
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
x_grad_name
);
}
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
ctx
->
SetOutputDim
(
y_grad_name
,
y_dims
);
ctx
->
ShareDim
(
"Y"
,
/*->*/
y_grad_name
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
}
}
...
...
@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
)
);
ctx
->
S
etOutputDim
(
x_grad_name
,
out_dims
);
ctx
->
ShareDim
(
framework
::
GradVarName
(
"Out"
),
/*->*/
x_grad_name
);
ctx
->
S
hareLoD
(
framework
::
GradVarName
(
"Out"
),
/*->*/
x_grad_name
);
}
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
ctx
->
SetOutputDim
(
y_grad_name
,
y_dims
);
ctx
->
ShareDim
(
"Y"
,
/*->*/
y_grad_name
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
}
}
};
...
...
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
9445502f
...
...
@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
"Input(X) of FakeDequantizeMaxAbsOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
9445502f
...
...
@@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
<<
" is set to LoDTensor"
;
block
->
Var
(
out_var_name
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
block
->
Var
(
out_var_name
)
->
SetDataType
(
block
->
Var
(
"W"
)
->
GetDataType
());
}
};
...
...
paddle/fluid/operators/prelu_op.cc
浏览文件 @
9445502f
...
...
@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
}
else
{
PADDLE_THROW
(
"Unkown mode %s"
,
mode
);
}
ctx
->
S
etOutputDim
(
"Out"
,
x_dim
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
paddle/fluid/operators/rnn_memory_helper_op.cc
浏览文件 @
9445502f
...
...
@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
"Input(X) of rnn_memory_helper op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output of rnn_memory_helper op should not be null."
);
ctx
->
S
etOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
)
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/sequence_conv_op.cc
浏览文件 @
9445502f
...
...
@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"PaddingData"
));
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
S
etOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
ctx
->
S
hareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter"
),
...
...
paddle/fluid/operators/sequence_pool_op.cc
浏览文件 @
9445502f
...
...
@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
for
(
int64_t
i
=
1
;
i
<
og_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
og_dims
[
i
],
x_dims
[
i
],
"The dimension mismatch."
);
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
protected:
...
...
paddle/fluid/operators/sequence_reshape_op.cc
浏览文件 @
9445502f
...
...
@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceReshapeGradOp should not be null."
);
ctx
->
S
etOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
S
hareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
};
...
...
paddle/fluid/operators/sequence_softmax_op.cc
浏览文件 @
9445502f
...
...
@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
"Input(X) of SequenceSoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SequenceSoftmaxOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
9445502f
...
...
@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
PADDLE_ENFORCE
(
context
->
HasOutput
(
framework
::
GradVarName
(
"X"
)));
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
context
->
GetInputDim
(
"X"
));
context
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
context
->
ShareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
context
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
};
...
...
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
9445502f
...
...
@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
"The 2nd dimension of Input(X) and Input(Label) should "
"be equal."
);
ctx
->
S
etOutputDim
(
"Out"
,
x_dims
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
9445502f
...
...
@@ -620,7 +620,23 @@ All parameter, weight, gradient are variables in Paddle.
// -- python binds for parallel executor.
py
::
class_
<
ParallelExecutor
>
pe
(
m
,
"ParallelExecutor"
);
py
::
class_
<
ExecutionStrategy
>
exec_strategy
(
pe
,
"ExecutionStrategy"
);
py
::
class_
<
ExecutionStrategy
>
exec_strategy
(
pe
,
"ExecutionStrategy"
,
R"DOC(
ExecutionStrategy allows the user to more preciously control how to run
the program in ParallelExecutor by setting the property.
The available properties include:
use_cuda (bool): Whether to use CUDA or not. Default True.
num_threads (int): The number of threads that used to run the
operators in ParallelExecutor. If it is not set, it will be
set in ParallelExecutor according to the device count.
Default 0.
allow_op_delay (bool): Whether to delay the communication operators
to run. Default False.
num_iteration_per_drop_scope (int): how many iterations between
the two dropping local scopes. Default 100.
)DOC"
);
exec_strategy
.
def
(
py
::
init
())
.
def_property
(
"num_threads"
,
...
...
@@ -658,7 +674,25 @@ All parameter, weight, gradient are variables in Paddle.
:
ExecutionStrategy
::
kDefault
;
});
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
);
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
,
R"DOC(
BuildStrategy allows the user to more preciously control how to
build the SSA Graph in ParallelExecutor by setting the property.
The available properties include:
reduce_strategy (str): There are two reduce strategies, 'AllReduce'
and 'Reduce'. If you want that all parameters will be optimized
on all devices, you can choose 'AllReduce'; if you choose
'Reduce', all parameters will be evenly allocated to different
devices for optimization, and then broadcast the optimized
parameter to other devices. Default 'AllReduce'.
gradient_scale_strategy (str): There are two ways of defining loss@grad,
'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor
sets the loss@grad according to the number of devices. If you want
to customize loss@grad, you can choose 'Customized'.
Default 'CoeffNumDevice'.
debug_graphviz_path (str): Whether to write the SSA Graph to file in the
form of graphviz. It is useful for debugging. Default "".
)DOC"
);
py
::
enum_
<
BuildStrategy
::
ReduceStrategy
>
(
build_strategy
,
"ReduceStrategy"
)
.
value
(
"Reduce"
,
BuildStrategy
::
ReduceStrategy
::
kReduce
)
...
...
python/paddle/fluid/lod_tensor.py
浏览文件 @
9445502f
...
...
@@ -74,7 +74,7 @@ def create_lod_tensor(data, recursive_seq_lens, place):
assert
[
new_recursive_seq_lens
]
==
recursive_seq_lens
,
"data and recursive_seq_lens do not match"
flattened_data
=
np
.
concatenate
(
data
,
axis
=
0
)
.
astype
(
"int64"
)
flattened_data
=
np
.
concatenate
(
data
,
axis
=
0
)
flattened_data
=
flattened_data
.
reshape
([
len
(
flattened_data
),
1
])
return
create_lod_tensor
(
flattened_data
,
recursive_seq_lens
,
place
)
elif
isinstance
(
data
,
np
.
ndarray
):
...
...
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
浏览文件 @
9445502f
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
class
ElementwiseMulOp
(
OpTest
):
...
...
@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
}
class
TestElementWiseMulSelectedRows
(
OpTest
):
def
setUp
(
self
):
self
.
rows
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
self
.
feature
=
12
self
.
height
=
100
self
.
input_shape
=
(
len
(
self
.
rows
),
self
.
feature
)
def
prepare_input
(
self
,
scope
,
place
):
self
.
input
=
{
"X"
:
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
),
"Y"
:
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
}
def
init_input
(
in_name
):
x_selected_rows
=
scope
.
var
(
in_name
).
get_selected_rows
()
x_selected_rows
.
set_height
(
self
.
height
)
x_selected_rows
.
set_rows
(
self
.
rows
)
x_array
=
self
.
input
[
in_name
]
x_tensor
=
x_selected_rows
.
get_tensor
()
x_tensor
.
set
(
x_array
,
place
)
init_input
(
"X"
)
init_input
(
"Y"
)
def
create_out_selected_row
(
self
,
scope
):
return
scope
.
var
(
'Out'
).
get_selected_rows
()
def
check_result
(
self
,
out_selected_rows
):
assert
out_selected_rows
.
height
()
==
self
.
height
assert
out_selected_rows
.
rows
()
==
self
.
rows
out_tensor
=
np
.
array
(
out_selected_rows
.
get_tensor
())
assert
out_tensor
.
shape
==
self
.
input_shape
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
self
.
prepare_input
(
scope
,
place
)
out_selected_rows
=
self
.
create_out_selected_row
(
scope
)
out_selected_rows
.
set_height
(
0
)
out_selected_rows
.
set_rows
([])
elementwise_mul
=
Operator
(
"elementwise_mul"
,
X
=
'X'
,
Y
=
'Y'
,
Out
=
'Out'
)
elementwise_mul
.
run
(
scope
,
place
)
self
.
check_result
(
out_selected_rows
)
def
test_elewisemul_with_selected_rows_input
(
self
):
places
=
[
core
.
CPUPlace
()]
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录