Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c7db6e8d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c7db6e8d
编写于
9月 13, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cond op passed
上级
b8e75c1f
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
198 addition
and
189 deletion
+198
-189
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+163
-3
paddle/operators/cond_op.h
paddle/operators/cond_op.h
+9
-164
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
python/paddle/v2/framework/op.py
python/paddle/v2/framework/op.py
+3
-3
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-0
python/paddle/v2/framework/tests/test_cond_op.py
python/paddle/v2/framework/tests/test_cond_op.py
+21
-19
未找到文件。
paddle/operators/cond_op.cc
浏览文件 @
c7db6e8d
...
...
@@ -13,15 +13,175 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"
namespace
paddle
{
namespace
operators
{
class
CondOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
using
Scope
=
framework
::
Scope
;
using
Variable
=
framework
::
Variable
;
using
Tensor
=
framework
::
Tensor
;
using
DDim
=
framework
::
DDim
;
void
CondOp
::
CreateScope
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE
(
sub_scopes_var
!=
nullptr
,
""
);
auto
sub_scopes
=
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
auto
&
sub_scope
=
scope
.
NewScope
();
sub_scopes
->
push_back
(
&
sub_scope
);
}
void
CondOp
::
CreateIndexTensor
(
const
Scope
&
scope
)
const
{
auto
index_tensors_var
=
scope
.
FindVar
(
"IndexTensors"
);
PADDLE_ENFORCE
(
index_tensors_var
!=
nullptr
,
""
);
auto
&
index_tensors
=
*
index_tensors_var
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
index_tensors
.
push_back
(
Tensor
());
}
void
CondOp
::
InferShape
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scopes_var
);
auto
&
sub_scopes
=
*
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope
(
scope
);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor
(
scope
);
PADDLE_ENFORCE
(
!
Inputs
(
"Xs"
).
empty
(),
"Inputs can't be empty"
);
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// Create a new tensor in sub-scope for input-type tensor
Variable
*
v
=
sub_scopes
[
i
]
->
NewVar
(
input
);
Tensor
*
sub_input
=
v
->
GetMutable
<
Tensor
>
();
sub_input
->
Resize
(
scope
.
FindVar
(
input
)
->
GetMutable
<
Tensor
>
()
->
dims
());
}
for
(
auto
&
output
:
(
*
sub_net_op_
[
i
]).
Outputs
())
{
for
(
auto
&
var_name
:
output
.
second
)
{
sub_scopes
[
i
]
->
NewVar
(
var_name
);
}
}
// each net calls InferShape
sub_net_op_
[
i
]
->
InferShape
(
*
sub_scopes
[
i
]);
}
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
Tensor
*
tensor_t_out
=
sub_scopes
[
0
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_t_out
,
"True output should be NULL"
);
Tensor
*
tensor_f_out
=
sub_scopes
[
1
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_f_out
,
"True output should be NULL"
);
auto
*
tensor_out_var
=
scope
.
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
tensor_out_var
,
"Output not found"
);
Tensor
*
tensor_out
=
tensor_out_var
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_t_out
,
"True output should be NULL"
);
// check output size should be same
PADDLE_ENFORCE_EQ
(
tensor_t_out
->
dims
(),
tensor_f_out
->
dims
(),
"Outputs not of the same shape"
);
tensor_out
->
Resize
(
tensor_t_out
->
dims
());
tensor_out
->
mutable_data
<
float
>
(
tensor_out
->
dims
(),
platform
::
CPUPlace
());
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
auto
sub_scopes
=
scope
.
FindVar
(
"SubScopes"
)
->
Get
<
std
::
vector
<
Scope
*>>
();
auto
index_tensors
=
scope
.
FindVar
(
"IndexTensors"
)
->
Get
<
std
::
vector
<
Tensor
>>
();
std
::
string
cond_name
=
Input
(
"Cond"
);
Variable
*
cond_var
=
scope
.
FindVar
(
cond_name
);
PADDLE_ENFORCE_NOT_NULL
(
cond_var
);
const
Tensor
*
cond
=
cond_var
->
GetMutable
<
Tensor
>
();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for
(
int
i
=
0
;
i
<
2
;
++
i
)
index_
[
i
].
clear
();
const
int
*
cond_data
=
cond
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
cond
->
dims
()[
0
];
++
i
)
{
if
(
cond_data
[
i
])
index_
[
0
].
push_back
(
i
);
else
index_
[
1
].
push_back
(
i
);
}
// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
DDim
dim
=
paddle
::
framework
::
make_ddim
({
0
});
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
dim
[
0
]
=
index_
[
i
].
size
();
int
*
tmp_ptr
=
index_tensors
[
i
].
mutable_data
<
int
>
(
dim
,
platform
::
CPUPlace
());
index_tensors
[
i
].
Resize
(
dim
);
memcpy
(
tmp_ptr
,
index_
[
i
].
data
(),
dim
[
0
]
*
sizeof
(
int
));
}
// Step 2: collect data by calling gather
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
input
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
input
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
// Resize child
DDim
dim
=
tensor_child
->
dims
();
dim
[
0
]
=
index_
[
i
].
size
();
tensor_child
->
Resize
(
dim
);
tensor_child
->
mutable_data
<
float
>
(
dim
,
platform
::
CPUPlace
());
Gather
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_parent
,
&
index_tensors
[
i
],
tensor_child
);
}
}
// Step 3: run
for
(
int
i
=
0
;
i
<
2
;
++
i
)
sub_net_op_
[
i
]
->
Run
(
*
sub_scopes
[
i
],
dev_ctx
);
// Step 4: merge output results
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
ScatterUpdate
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_child
,
&
index_tensors
[
i
],
tensor_parent
);
}
}
}
class
CondOpProtoAndCheckerMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CondOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
CondOpProtoAndCheckerMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Cond"
,
"The condition, which is a bool vector"
);
AddInput
(
"Xs"
,
"Inputs of Subnets"
).
AsDuplicable
();
...
...
@@ -41,5 +201,5 @@ Out[i] = subnet_t[i], if Cond[i] == false
}
// namespace operators
}
// namespace paddle
REGISTER_OP_WITHOUT_GRADIENT
(
cond
_op
,
paddle
::
operators
::
CondOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
cond
,
paddle
::
operators
::
CondOp
,
paddle
::
operators
::
CondOpProtoAndCheckerMaker
);
paddle/operators/cond_op.h
浏览文件 @
c7db6e8d
...
...
@@ -19,22 +19,19 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
operators
{
using
namespace
paddle
::
framework
;
class
CondOp
:
public
OperatorBase
{
class
CondOp
:
public
framework
::
OperatorBase
{
public:
CondOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
CondOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
index_
.
resize
(
2
);
sub_net_op_
.
resize
(
2
);
LOG
(
INFO
)
<<
"Initialization Done."
;
}
CondOp
(
const
CondOp
&
o
)
...
...
@@ -44,87 +41,14 @@ class CondOp : public OperatorBase {
PADDLE_THROW
(
"Not implemented"
);
}
void
CreateScope
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE
(
sub_scopes_var
!=
nullptr
,
""
);
auto
sub_scopes
=
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
auto
&
sub_scope
=
scope
.
NewScope
();
sub_scopes
->
push_back
(
&
sub_scope
);
}
void
CreateScope
(
const
framework
::
Scope
&
scope
)
const
;
void
CreateIndexTensor
(
const
Scope
&
scope
)
const
{
auto
index_tensors_var
=
scope
.
FindVar
(
"IndexTensors"
);
PADDLE_ENFORCE
(
index_tensors_var
!=
nullptr
,
""
);
auto
&
index_tensors
=
*
index_tensors_var
->
GetMutable
<
std
::
vector
<
Tensor
*>>
();
Tensor
index_tensor
;
index_tensors
.
push_back
(
&
index_tensor
);
}
void
CreateIndexTensor
(
const
framework
::
Scope
&
scope
)
const
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scopes_var
);
auto
&
sub_scopes
=
*
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
// auto& index_tensors =
// *scope.FindVar("IndexTensors")->GetMutable<std::vector<Tensor*>>();
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope
(
scope
);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor
(
scope
);
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// Create a new tensor in sub-scope for input-type tensor
Variable
*
v
=
sub_scopes
[
i
]
->
NewVar
(
input
);
Tensor
*
sub_input
=
v
->
GetMutable
<
Tensor
>
();
sub_input
->
Resize
(
scope
.
FindVar
(
input
)
->
GetMutable
<
Tensor
>
()
->
dims
());
}
// Inputs that do not require tailoring
/*for (auto& input : (*sub_net_op_[i]).Inputs()) {
// weights are located in the parent scope rather than sub scope
for (auto& var_name : input.second) {
if (!sub_scopes[i]->FindVar(var_name)) {
sub_scopes[i]->NewVar(var_name)->GetMutable<Tensor>();
}
}
}*/
// Outputs
for
(
auto
&
output
:
(
*
sub_net_op_
[
i
]).
Outputs
())
{
for
(
auto
&
var_name
:
output
.
second
)
{
sub_scopes
[
i
]
->
NewVar
(
var_name
);
}
}
// each net calls InferShape
LOG
(
INFO
)
<<
"OK 3"
;
sub_net_op_
[
i
]
->
InferShape
(
*
sub_scopes
[
i
]);
LOG
(
INFO
)
<<
"OK 4"
;
}
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
Tensor
*
tensor_t_out
=
sub_scopes
[
0
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
Tensor
*
tensor_f_out
=
sub_scopes
[
1
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
Tensor
*
tensor_out
=
scope
.
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
// check output size should be same
PADDLE_ENFORCE_EQ
(
tensor_t_out
->
dims
(),
tensor_f_out
->
dims
(),
"Outputs not of the same shape"
);
tensor_out
->
Resize
(
tensor_t_out
->
dims
());
}
LOG
(
INFO
)
<<
"OK 5"
;
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
;
// Set True Block
void
set_truenet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
...
...
@@ -137,74 +61,7 @@ class CondOp : public OperatorBase {
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
auto
sub_scopes
=
scope
.
FindVar
(
"SubScopes"
)
->
Get
<
std
::
vector
<
Scope
*>>
();
auto
index_tensors
=
scope
.
FindVar
(
"IndexTensors"
)
->
Get
<
std
::
vector
<
Tensor
*>>
();
std
::
string
cond_name
=
Input
(
"Cond"
);
Variable
*
cond_var
=
scope
.
FindVar
(
cond_name
);
PADDLE_ENFORCE_NOT_NULL
(
cond_var
)
const
Tensor
*
cond
=
cond_var
->
GetMutable
<
Tensor
>
();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for
(
int
i
=
0
;
i
<
2
;
++
i
)
index_
[
i
].
clear
();
const
bool
*
cond_data
=
cond
->
data
<
bool
>
();
for
(
int
i
=
0
;
i
<
cond
->
dims
()[
0
];
++
i
)
{
if
(
cond_data
[
i
])
index_
[
0
].
push_back
(
i
);
else
index_
[
1
].
push_back
(
i
);
}
// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
framework
::
DDim
dim
=
paddle
::
framework
::
make_ddim
({
0
});
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
dim
[
0
]
=
index_
[
i
].
size
();
int
*
tmp_ptr
=
index_tensors
[
i
]
->
mutable_data
<
int
>
(
dim
,
platform
::
CPUPlace
());
index_tensors
[
i
]
->
Resize
(
dim
);
memcpy
(
tmp_ptr
,
index_
[
i
].
data
(),
dim
[
0
]
*
sizeof
(
int
));
}
// Step 2: collect data by calling gather
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// find Tensor
// Tensor* tensor_parent = scope.FindVar(input)->GetMutable<Tensor>();
Variable
*
v
=
scope
.
FindVar
(
input
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
// Tensor* tensor_child =
// sub_scope_[i].FindVar(input)->GetMutable<Tensor>();
v
=
sub_scopes
[
i
]
->
FindVar
(
input
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
Gather
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_parent
,
index_tensors
[
i
],
tensor_child
);
}
}
// Step 3: run
for
(
int
i
=
0
;
i
<
2
;
++
i
)
sub_net_op_
[
i
]
->
Run
(
*
sub_scopes
[
i
],
dev_ctx
);
// Step 4: merge output results
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
// for (auto& output : GetAttr<std::vector<std::string>>("sub_outputs")) {
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
output
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
output
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
ScatterUpdate
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_child
,
index_tensors
[
i
],
tensor_parent
);
}
}
}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
;
private:
// sub_net_op_[0]: subnet_t
...
...
@@ -216,17 +73,5 @@ class CondOp : public OperatorBase {
mutable
std
::
vector
<
std
::
vector
<
int
>>
index_
;
};
/*
class CondGradientOp final : public OperatorBase {
public:
void Init() override;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const
override;
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override;
};*/
}
// namespace operators
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
c7db6e8d
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cond_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
...
...
python/paddle/v2/framework/op.py
浏览文件 @
c7db6e8d
...
...
@@ -217,7 +217,7 @@ class __RecurrentOp__(object):
class
__CondOp__
(
object
):
__proto__
=
None
type
=
'cond_op'
type
=
"cond"
def
__init__
(
self
):
# cache recurrent_op's proto
...
...
@@ -227,8 +227,8 @@ class __CondOp__(object):
self
.
__proto__
=
op_proto
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
self
.
type
not
in
args
and
'type'
not
in
kwargs
:
kwargs
[
'type'
]
=
self
.
type
if
self
.
type
not
in
args
and
"type"
not
in
kwargs
:
kwargs
[
"type"
]
=
self
.
type
# create proto
create_method
=
OpDescCreationMethod
(
self
.
__proto__
)
proto
=
create_method
(
*
args
,
**
kwargs
)
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
c7db6e8d
...
...
@@ -27,6 +27,7 @@ py_test(test_operator SRCS test_operator.py)
py_test
(
test_gaussian_random_op SRCS test_gaussian_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_cond_op SRCS test_cond_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_gradient_checker SRCS test_gradient_checker.py
)
py_test
(
test_lookup_table SRCS test_lookup_table.py
)
...
...
python/paddle/v2/framework/tests/test_cond_op.py
浏览文件 @
c7db6e8d
...
...
@@ -11,15 +11,15 @@ class PySimpleCond(object):
'''
def
__init__
(
self
):
array
=
[
True
]
*
10
array
=
[
1
]
*
10
for
i
in
range
(
1
,
10
,
2
):
array
[
i
]
=
False
array
[
i
]
=
0
self
.
cond
=
np
.
array
(
array
)
self
.
x
=
np
.
ones
(
shape
=
(
10
,
1
))
def
forward
(
self
):
self
.
index_t
=
np
.
where
(
self
.
cond
)
self
.
index_f
=
np
.
where
(
self
.
cond
==
False
)
self
.
index_t
=
np
.
where
(
self
.
cond
==
1
)
self
.
index_f
=
np
.
where
(
self
.
cond
==
0
)
y_t
=
self
.
x
[
self
.
index_t
]
y_f
=
self
.
x
[
self
.
index_f
]
y_t
=
y_t
*
2.
...
...
@@ -36,7 +36,6 @@ class PySimpleCondTest(unittest.TestCase):
def
test_forward
(
self
):
output
=
self
.
condnn
.
forward
()
print
'output'
,
output
def
create_tensor
(
scope
,
name
,
shape
,
np_data
):
...
...
@@ -67,47 +66,50 @@ class TestCondOp(unittest.TestCase):
self
.
create_cond_op
()
self
.
create_sub_net
()
ctx
=
core
.
DeviceContext
.
create
(
core
.
CPUPlace
())
print
'running infer shape'
print
self
.
scope
.
find_var
(
"SubScopes"
)
self
.
condop
.
infer_shape
(
self
.
scope
)
print
'ok 2'
self
.
condop
.
run
(
self
.
scope
,
ctx
)
print
'ok 3'
return
np
.
array
(
self
.
scope
.
find_var
(
"Outs"
).
get_tensor
())
return
np
.
array
(
self
.
scope
.
find_var
(
"Out"
).
get_tensor
())
def
create_global_variables
(
self
):
x_np_data
=
self
.
py_cond
.
x
create_tensor
(
self
.
scope
,
"
x
"
,
[
10
,
1
],
x_np_data
)
cond_np_data
=
self
.
py_cond
.
cond
create_tensor
(
self
.
scope
,
"cond"
,
[
10
,
1
],
x
_np_data
)
create_tensor
(
self
.
scope
,
"
X
"
,
[
10
,
1
],
x_np_data
)
cond_np_data
=
self
.
py_cond
.
cond
.
astype
(
"int32"
)
create_tensor
(
self
.
scope
,
"cond"
,
[
10
,
1
],
cond
_np_data
)
self
.
scope
.
new_var
(
"SubScopes"
)
self
.
scope
.
new_var
(
"IndexTensors"
)
self
.
scope
.
new_var
(
"Out
s
"
)
self
.
scope
.
new_var
(
"Out"
)
def
create_cond_op
(
self
):
self
.
condop
=
CondOp
(
Cond
=
"cond"
,
Xs
=
[
"
x
"
],
Outs
=
[
'Out_final'
],
Xs
=
[
"
X
"
],
Outs
=
[
"Out"
],
SubScopes
=
"SubScopes"
,
IndexTensors
=
"IndexTensors"
)
def
create_sub_net
(
self
):
truenet
=
core
.
Net
.
create
()
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Y
=
'Out'
,
scale
=
2.
)
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Out
=
'Out'
,
scale
=
2.
)
truenet
.
append_op
(
scale_op_t
)
truenet
.
complete_add_op
(
True
)
self
.
condop
.
set_truenet
(
truenet
)
falsenet
=
core
.
Net
.
create
()
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Y
=
'Out'
,
scale
=-
2.
)
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Out
=
'Out'
,
scale
=-
2.
)
falsenet
.
append_op
(
scale_op_t
)
falsenet
.
complete_add_op
(
True
)
self
.
condop
.
set_falsenet
(
falsenet
)
def
test_forward
(
self
):
print
'test cond op forward'
py_output
=
self
.
forward
()
pd_output
=
self
.
forward
()
py_output
=
self
.
py_cond
.
forward
()
print
'pd_output'
,
pd_output
print
print
'py_output'
,
py_output
self
.
assertEqual
(
pd_output
.
shape
,
py_output
.
shape
)
print
'test passed'
return
0
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录