Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
39e129a1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39e129a1
编写于
10月 10, 2018
作者:
X
Xin Pan
提交者:
GitHub
10月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13794 from chengduoZH/set_right_shape_of_seleceted_rows_release
Set the output shape of seleceted rows in right way.
上级
b2e6e5f2
e7e69e23
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
162 addition
and
29 deletion
+162
-29
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+21
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+36
-4
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+3
-0
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+18
-2
paddle/fluid/operators/argsort_op.cc
paddle/fluid/operators/argsort_op.cc
+2
-2
paddle/fluid/operators/conv_shift_op.cc
paddle/fluid/operators/conv_shift_op.cc
+1
-1
paddle/fluid/operators/elementwise_op.h
paddle/fluid/operators/elementwise_op.h
+12
-7
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+2
-1
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+1
-1
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
+1
-1
paddle/fluid/operators/sequence_conv_op.cc
paddle/fluid/operators/sequence_conv_op.cc
+2
-2
paddle/fluid/operators/sequence_pool_op.cc
paddle/fluid/operators/sequence_pool_op.cc
+3
-2
paddle/fluid/operators/sequence_reshape_op.cc
paddle/fluid/operators/sequence_reshape_op.cc
+1
-1
paddle/fluid/operators/sequence_softmax_op.cc
paddle/fluid/operators/sequence_softmax_op.cc
+2
-1
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+3
-3
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
...e/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
+1
-1
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+53
-0
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
39e129a1
...
...
@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
override
;
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
PADDLE_ENFORCE
(
input_n
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
in
,
i
);
PADDLE_ENFORCE
(
output_n
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
out
,
j
);
auto
*
in_var
=
block_
.
FindVarRecursive
(
input_n
);
auto
*
out_var
=
block_
.
FindVarRecursive
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
GetType
()
==
out_var
->
GetType
(),
"The type of %s and %s is not the same."
,
input_n
,
output_n
);
SetDim
(
output_n
,
GetDim
(
input_n
));
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
39e129a1
...
...
@@ -542,13 +542,45 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
op_
.
Outputs
(
name
);
}
void
Share
LoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
void
Share
Dim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
Variable
*
in_var
=
scope_
.
FindVar
(
Inputs
(
in
)[
i
]);
Variable
*
out_var
=
scope_
.
FindVar
(
Outputs
(
out
)[
j
]);
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
Variable
*
in_var
=
scope_
.
FindVar
(
input_n
);
Variable
*
out_var
=
scope_
.
FindVar
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
Type
()
==
out_var
->
Type
(),
"The type of %s and %s is not the same."
,
output_n
,
GetDim
(
input_n
));
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
out_sele_rows
=
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_sele_rows
->
mutable_value
()
->
Resize
(
in_sele_rows
.
value
().
dims
());
out_sele_rows
->
set_rows
(
in_sele_rows
.
rows
());
out_sele_rows
->
set_height
(
in_sele_rows
.
height
());
}
else
if
(
in_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
in_lod_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_lod_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
out_lod_tensor
->
Resize
(
in_lod_tensor
.
dims
());
}
else
{
PADDLE_THROW
(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."
);
}
}
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
const
std
::
vector
<
std
::
string
>&
inputs
=
Inputs
(
in
);
const
std
::
vector
<
std
::
string
>&
outputs
=
Outputs
(
out
);
PADDLE_ENFORCE_LT
(
i
,
inputs
.
size
());
PADDLE_ENFORCE_LT
(
j
,
outputs
.
size
());
Variable
*
in_var
=
scope_
.
FindVar
(
inputs
.
at
(
i
));
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
scope_
.
FindVar
(
outputs
.
at
(
j
));
PADDLE_ENFORCE
(
out_var
->
IsType
<
LoDTensor
>
(),
"The %d-th output of Output(%s) must be LoDTensor."
,
j
,
out
);
auto
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
39e129a1
...
...
@@ -58,6 +58,9 @@ class InferShapeContext {
void
ShareLoDs
(
const
std
::
string
&
in
,
const
std
::
string
&
out
)
const
;
virtual
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
=
0
;
virtual
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
=
0
;
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
39e129a1
...
...
@@ -79,7 +79,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
S
etOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
)
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -90,12 +90,26 @@ class ActivationOp : public framework::OperatorWithKernel {
}
};
class
ActivationOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
x_name
=
op_desc
.
Input
(
"X"
)[
0
];
auto
out_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
&
x
=
block
->
FindRecursiveOrCreateVar
(
x_name
);
auto
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
out
.
SetType
(
x
.
GetType
());
out
.
SetDataType
(
x
.
GetDataType
());
}
};
class
ActivationOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
ctx
->
ShareDim
(
"Out"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"Out"
,
framework
::
GradVarName
(
"X"
));
}
protected:
...
...
@@ -524,12 +538,14 @@ namespace ops = paddle::operators;
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
...
...
paddle/fluid/operators/argsort_op.cc
浏览文件 @
39e129a1
...
...
@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
"-rank(Input(X)) (%d)."
,
axis
,
num_dims
);
ctx
->
S
etOutputDim
(
"Out"
,
in_dims
);
ctx
->
S
etOutputDim
(
"Indices"
,
in_dims
);
ctx
->
S
hareDim
(
"X"
,
"Out"
);
ctx
->
S
hareDim
(
"X"
,
"Indices"
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
ctx
->
ShareLoD
(
"X"
,
"Indices"
);
}
...
...
paddle/fluid/operators/conv_shift_op.cc
浏览文件 @
39e129a1
...
...
@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE
(
y_dims
[
1
],
x_dims
[
1
],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X)."
);
ctx
->
S
etOutputDim
(
"Out"
,
x_dims
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/elementwise_op.h
浏览文件 @
39e129a1
...
...
@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto
y_dim
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_GE
(
x_dim
.
size
(),
y_dim
.
size
(),
"Rank of first input must >= rank of second input."
);
ctx
->
SetOutputDim
(
"Out"
,
x_dim
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
auto
&
x
=
block
->
FindRecursiveOrCreateVar
(
x_name
);
auto
&
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
out
.
SetType
(
x
.
GetType
());
out
.
SetDataType
(
x
.
GetDataType
());
}
};
...
...
@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
x_dims
);
ctx
->
ShareDim
(
"X"
,
/*->*/
x_grad_name
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
x_grad_name
);
}
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
ctx
->
SetOutputDim
(
y_grad_name
,
y_dims
);
ctx
->
ShareDim
(
"Y"
,
/*->*/
y_grad_name
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
}
}
...
...
@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
)
);
ctx
->
S
etOutputDim
(
x_grad_name
,
out_dims
);
ctx
->
ShareDim
(
framework
::
GradVarName
(
"Out"
),
/*->*/
x_grad_name
);
ctx
->
S
hareLoD
(
framework
::
GradVarName
(
"Out"
),
/*->*/
x_grad_name
);
}
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
ctx
->
SetOutputDim
(
y_grad_name
,
y_dims
);
ctx
->
ShareDim
(
"Y"
,
/*->*/
y_grad_name
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
}
}
};
...
...
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
39e129a1
...
...
@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
"Input(X) of FakeDequantizeMaxAbsOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/prelu_op.cc
浏览文件 @
39e129a1
...
...
@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
}
else
{
PADDLE_THROW
(
"Unkown mode %s"
,
mode
);
}
ctx
->
S
etOutputDim
(
"Out"
,
x_dim
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
paddle/fluid/operators/rnn_memory_helper_op.cc
浏览文件 @
39e129a1
...
...
@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
"Input(X) of rnn_memory_helper op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output of rnn_memory_helper op should not be null."
);
ctx
->
S
etOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
)
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
paddle/fluid/operators/sequence_conv_op.cc
浏览文件 @
39e129a1
...
...
@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
ctx
->
GetInputDim
(
"PaddingData"
));
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
S
etOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
ctx
->
S
hareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter"
),
...
...
paddle/fluid/operators/sequence_pool_op.cc
浏览文件 @
39e129a1
...
...
@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
for
(
int64_t
i
=
1
;
i
<
og_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
og_dims
[
i
],
x_dims
[
i
],
"The dimension mismatch."
);
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
protected:
...
...
paddle/fluid/operators/sequence_reshape_op.cc
浏览文件 @
39e129a1
...
...
@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequenceReshapeGradOp should not be null."
);
ctx
->
S
etOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
S
hareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
};
...
...
paddle/fluid/operators/sequence_softmax_op.cc
浏览文件 @
39e129a1
...
...
@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
"Input(X) of SequenceSoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SequenceSoftmaxOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
39e129a1
...
...
@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
PADDLE_ENFORCE
(
context
->
HasOutput
(
framework
::
GradVarName
(
"X"
)));
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
context
->
GetInputDim
(
"X"
));
context
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
context
->
ShareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
context
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
};
...
...
paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
39e129a1
...
...
@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
"The 2nd dimension of Input(X) and Input(Label) should "
"be equal."
);
ctx
->
S
etOutputDim
(
"Out"
,
x_dims
);
ctx
->
S
hareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
...
...
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
浏览文件 @
39e129a1
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
class
ElementwiseMulOp
(
OpTest
):
...
...
@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
}
class
TestElementWiseMulSelectedRows
(
OpTest
):
def
setUp
(
self
):
self
.
rows
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
self
.
feature
=
12
self
.
height
=
100
self
.
input_shape
=
(
len
(
self
.
rows
),
self
.
feature
)
def
prepare_input
(
self
,
scope
,
place
):
self
.
input
=
{
"X"
:
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
),
"Y"
:
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
}
def
init_input
(
in_name
):
x_selected_rows
=
scope
.
var
(
in_name
).
get_selected_rows
()
x_selected_rows
.
set_height
(
self
.
height
)
x_selected_rows
.
set_rows
(
self
.
rows
)
x_array
=
self
.
input
[
in_name
]
x_tensor
=
x_selected_rows
.
get_tensor
()
x_tensor
.
set
(
x_array
,
place
)
init_input
(
"X"
)
init_input
(
"Y"
)
def
create_out_selected_row
(
self
,
scope
):
return
scope
.
var
(
'Out'
).
get_selected_rows
()
def
check_result
(
self
,
out_selected_rows
):
assert
out_selected_rows
.
height
()
==
self
.
height
assert
out_selected_rows
.
rows
()
==
self
.
rows
out_tensor
=
np
.
array
(
out_selected_rows
.
get_tensor
())
assert
out_tensor
.
shape
==
self
.
input_shape
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
self
.
prepare_input
(
scope
,
place
)
out_selected_rows
=
self
.
create_out_selected_row
(
scope
)
out_selected_rows
.
set_height
(
0
)
out_selected_rows
.
set_rows
([])
elementwise_mul
=
Operator
(
"elementwise_mul"
,
X
=
'X'
,
Y
=
'Y'
,
Out
=
'Out'
)
elementwise_mul
.
run
(
scope
,
place
)
self
.
check_result
(
out_selected_rows
)
def
test_elewisemul_with_selected_rows_input
(
self
):
places
=
[
core
.
CPUPlace
()]
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录