Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c7db6e8d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c7db6e8d
编写于
9月 13, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cond op passed
上级
b8e75c1f
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
198 addition
and
189 deletion
+198
-189
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+163
-3
paddle/operators/cond_op.h
paddle/operators/cond_op.h
+9
-164
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
python/paddle/v2/framework/op.py
python/paddle/v2/framework/op.py
+3
-3
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-0
python/paddle/v2/framework/tests/test_cond_op.py
python/paddle/v2/framework/tests/test_cond_op.py
+21
-19
未找到文件。
paddle/operators/cond_op.cc
浏览文件 @
c7db6e8d
...
...
@@ -13,15 +13,175 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cond_op.h"
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"
namespace
paddle
{
namespace
operators
{
class
CondOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
using
Scope
=
framework
::
Scope
;
using
Variable
=
framework
::
Variable
;
using
Tensor
=
framework
::
Tensor
;
using
DDim
=
framework
::
DDim
;
void
CondOp
::
CreateScope
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE
(
sub_scopes_var
!=
nullptr
,
""
);
auto
sub_scopes
=
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
auto
&
sub_scope
=
scope
.
NewScope
();
sub_scopes
->
push_back
(
&
sub_scope
);
}
void
CondOp
::
CreateIndexTensor
(
const
Scope
&
scope
)
const
{
auto
index_tensors_var
=
scope
.
FindVar
(
"IndexTensors"
);
PADDLE_ENFORCE
(
index_tensors_var
!=
nullptr
,
""
);
auto
&
index_tensors
=
*
index_tensors_var
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
index_tensors
.
push_back
(
Tensor
());
}
void
CondOp
::
InferShape
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scopes_var
);
auto
&
sub_scopes
=
*
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope
(
scope
);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor
(
scope
);
PADDLE_ENFORCE
(
!
Inputs
(
"Xs"
).
empty
(),
"Inputs can't be empty"
);
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// Create a new tensor in sub-scope for input-type tensor
Variable
*
v
=
sub_scopes
[
i
]
->
NewVar
(
input
);
Tensor
*
sub_input
=
v
->
GetMutable
<
Tensor
>
();
sub_input
->
Resize
(
scope
.
FindVar
(
input
)
->
GetMutable
<
Tensor
>
()
->
dims
());
}
for
(
auto
&
output
:
(
*
sub_net_op_
[
i
]).
Outputs
())
{
for
(
auto
&
var_name
:
output
.
second
)
{
sub_scopes
[
i
]
->
NewVar
(
var_name
);
}
}
// each net calls InferShape
sub_net_op_
[
i
]
->
InferShape
(
*
sub_scopes
[
i
]);
}
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
Tensor
*
tensor_t_out
=
sub_scopes
[
0
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_t_out
,
"True output should be NULL"
);
Tensor
*
tensor_f_out
=
sub_scopes
[
1
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_f_out
,
"True output should be NULL"
);
auto
*
tensor_out_var
=
scope
.
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
tensor_out_var
,
"Output not found"
);
Tensor
*
tensor_out
=
tensor_out_var
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE_NOT_NULL
(
tensor_t_out
,
"True output should be NULL"
);
// check output size should be same
PADDLE_ENFORCE_EQ
(
tensor_t_out
->
dims
(),
tensor_f_out
->
dims
(),
"Outputs not of the same shape"
);
tensor_out
->
Resize
(
tensor_t_out
->
dims
());
tensor_out
->
mutable_data
<
float
>
(
tensor_out
->
dims
(),
platform
::
CPUPlace
());
}
}
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
auto
sub_scopes
=
scope
.
FindVar
(
"SubScopes"
)
->
Get
<
std
::
vector
<
Scope
*>>
();
auto
index_tensors
=
scope
.
FindVar
(
"IndexTensors"
)
->
Get
<
std
::
vector
<
Tensor
>>
();
std
::
string
cond_name
=
Input
(
"Cond"
);
Variable
*
cond_var
=
scope
.
FindVar
(
cond_name
);
PADDLE_ENFORCE_NOT_NULL
(
cond_var
);
const
Tensor
*
cond
=
cond_var
->
GetMutable
<
Tensor
>
();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for
(
int
i
=
0
;
i
<
2
;
++
i
)
index_
[
i
].
clear
();
const
int
*
cond_data
=
cond
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
cond
->
dims
()[
0
];
++
i
)
{
if
(
cond_data
[
i
])
index_
[
0
].
push_back
(
i
);
else
index_
[
1
].
push_back
(
i
);
}
// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
DDim
dim
=
paddle
::
framework
::
make_ddim
({
0
});
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
dim
[
0
]
=
index_
[
i
].
size
();
int
*
tmp_ptr
=
index_tensors
[
i
].
mutable_data
<
int
>
(
dim
,
platform
::
CPUPlace
());
index_tensors
[
i
].
Resize
(
dim
);
memcpy
(
tmp_ptr
,
index_
[
i
].
data
(),
dim
[
0
]
*
sizeof
(
int
));
}
// Step 2: collect data by calling gather
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
input
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
input
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
// Resize child
DDim
dim
=
tensor_child
->
dims
();
dim
[
0
]
=
index_
[
i
].
size
();
tensor_child
->
Resize
(
dim
);
tensor_child
->
mutable_data
<
float
>
(
dim
,
platform
::
CPUPlace
());
Gather
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_parent
,
&
index_tensors
[
i
],
tensor_child
);
}
}
// Step 3: run
for
(
int
i
=
0
;
i
<
2
;
++
i
)
sub_net_op_
[
i
]
->
Run
(
*
sub_scopes
[
i
],
dev_ctx
);
// Step 4: merge output results
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
v
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
ScatterUpdate
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_child
,
&
index_tensors
[
i
],
tensor_parent
);
}
}
}
class
CondOpProtoAndCheckerMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CondOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
CondOpProtoAndCheckerMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Cond"
,
"The condition, which is a bool vector"
);
AddInput
(
"Xs"
,
"Inputs of Subnets"
).
AsDuplicable
();
...
...
@@ -41,5 +201,5 @@ Out[i] = subnet_t[i], if Cond[i] == false
}
// namespace operators
}
// namespace paddle
REGISTER_OP_WITHOUT_GRADIENT
(
cond
_op
,
paddle
::
operators
::
CondOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
cond
,
paddle
::
operators
::
CondOp
,
paddle
::
operators
::
CondOpProtoAndCheckerMaker
);
paddle/operators/cond_op.h
浏览文件 @
c7db6e8d
...
...
@@ -19,22 +19,19 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
operators
{
using
namespace
paddle
::
framework
;
class
CondOp
:
public
OperatorBase
{
class
CondOp
:
public
framework
::
OperatorBase
{
public:
CondOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
CondOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
index_
.
resize
(
2
);
sub_net_op_
.
resize
(
2
);
LOG
(
INFO
)
<<
"Initialization Done."
;
}
CondOp
(
const
CondOp
&
o
)
...
...
@@ -44,87 +41,14 @@ class CondOp : public OperatorBase {
PADDLE_THROW
(
"Not implemented"
);
}
void
CreateScope
(
const
Scope
&
scope
)
const
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE
(
sub_scopes_var
!=
nullptr
,
""
);
auto
sub_scopes
=
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
auto
&
sub_scope
=
scope
.
NewScope
();
sub_scopes
->
push_back
(
&
sub_scope
);
}
void
CreateScope
(
const
framework
::
Scope
&
scope
)
const
;
void
CreateIndexTensor
(
const
Scope
&
scope
)
const
{
auto
index_tensors_var
=
scope
.
FindVar
(
"IndexTensors"
);
PADDLE_ENFORCE
(
index_tensors_var
!=
nullptr
,
""
);
auto
&
index_tensors
=
*
index_tensors_var
->
GetMutable
<
std
::
vector
<
Tensor
*>>
();
Tensor
index_tensor
;
index_tensors
.
push_back
(
&
index_tensor
);
}
void
CreateIndexTensor
(
const
framework
::
Scope
&
scope
)
const
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
auto
sub_scopes_var
=
scope
.
FindVar
(
"SubScopes"
);
PADDLE_ENFORCE_NOT_NULL
(
sub_scopes_var
);
auto
&
sub_scopes
=
*
sub_scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
// auto& index_tensors =
// *scope.FindVar("IndexTensors")->GetMutable<std::vector<Tensor*>>();
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope
(
scope
);
// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor
(
scope
);
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// Create a new tensor in sub-scope for input-type tensor
Variable
*
v
=
sub_scopes
[
i
]
->
NewVar
(
input
);
Tensor
*
sub_input
=
v
->
GetMutable
<
Tensor
>
();
sub_input
->
Resize
(
scope
.
FindVar
(
input
)
->
GetMutable
<
Tensor
>
()
->
dims
());
}
// Inputs that do not require tailoring
/*for (auto& input : (*sub_net_op_[i]).Inputs()) {
// weights are located in the parent scope rather than sub scope
for (auto& var_name : input.second) {
if (!sub_scopes[i]->FindVar(var_name)) {
sub_scopes[i]->NewVar(var_name)->GetMutable<Tensor>();
}
}
}*/
// Outputs
for
(
auto
&
output
:
(
*
sub_net_op_
[
i
]).
Outputs
())
{
for
(
auto
&
var_name
:
output
.
second
)
{
sub_scopes
[
i
]
->
NewVar
(
var_name
);
}
}
// each net calls InferShape
LOG
(
INFO
)
<<
"OK 3"
;
sub_net_op_
[
i
]
->
InferShape
(
*
sub_scopes
[
i
]);
LOG
(
INFO
)
<<
"OK 4"
;
}
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
Tensor
*
tensor_t_out
=
sub_scopes
[
0
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
Tensor
*
tensor_f_out
=
sub_scopes
[
1
]
->
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
Tensor
*
tensor_out
=
scope
.
FindVar
(
output
)
->
GetMutable
<
Tensor
>
();
// check output size should be same
PADDLE_ENFORCE_EQ
(
tensor_t_out
->
dims
(),
tensor_f_out
->
dims
(),
"Outputs not of the same shape"
);
tensor_out
->
Resize
(
tensor_t_out
->
dims
());
}
LOG
(
INFO
)
<<
"OK 5"
;
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
;
// Set True Block
void
set_truenet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
...
...
@@ -137,74 +61,7 @@ class CondOp : public OperatorBase {
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
auto
sub_scopes
=
scope
.
FindVar
(
"SubScopes"
)
->
Get
<
std
::
vector
<
Scope
*>>
();
auto
index_tensors
=
scope
.
FindVar
(
"IndexTensors"
)
->
Get
<
std
::
vector
<
Tensor
*>>
();
std
::
string
cond_name
=
Input
(
"Cond"
);
Variable
*
cond_var
=
scope
.
FindVar
(
cond_name
);
PADDLE_ENFORCE_NOT_NULL
(
cond_var
)
const
Tensor
*
cond
=
cond_var
->
GetMutable
<
Tensor
>
();
// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for
(
int
i
=
0
;
i
<
2
;
++
i
)
index_
[
i
].
clear
();
const
bool
*
cond_data
=
cond
->
data
<
bool
>
();
for
(
int
i
=
0
;
i
<
cond
->
dims
()[
0
];
++
i
)
{
if
(
cond_data
[
i
])
index_
[
0
].
push_back
(
i
);
else
index_
[
1
].
push_back
(
i
);
}
// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
framework
::
DDim
dim
=
paddle
::
framework
::
make_ddim
({
0
});
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
dim
[
0
]
=
index_
[
i
].
size
();
int
*
tmp_ptr
=
index_tensors
[
i
]
->
mutable_data
<
int
>
(
dim
,
platform
::
CPUPlace
());
index_tensors
[
i
]
->
Resize
(
dim
);
memcpy
(
tmp_ptr
,
index_
[
i
].
data
(),
dim
[
0
]
*
sizeof
(
int
));
}
// Step 2: collect data by calling gather
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
for
(
auto
&
input
:
Inputs
(
"Xs"
))
{
// find Tensor
// Tensor* tensor_parent = scope.FindVar(input)->GetMutable<Tensor>();
Variable
*
v
=
scope
.
FindVar
(
input
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
// Tensor* tensor_child =
// sub_scope_[i].FindVar(input)->GetMutable<Tensor>();
v
=
sub_scopes
[
i
]
->
FindVar
(
input
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
Gather
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_parent
,
index_tensors
[
i
],
tensor_child
);
}
}
// Step 3: run
for
(
int
i
=
0
;
i
<
2
;
++
i
)
sub_net_op_
[
i
]
->
Run
(
*
sub_scopes
[
i
],
dev_ctx
);
// Step 4: merge output results
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
// i= 0/i for True and False branches respectively
// for (auto& output : GetAttr<std::vector<std::string>>("sub_outputs")) {
for
(
auto
&
output
:
Outputs
(
"Outs"
))
{
// find Tensor
Variable
*
v
=
scope
.
FindVar
(
output
);
Tensor
*
tensor_parent
=
v
->
GetMutable
<
Tensor
>
();
v
=
sub_scopes
[
i
]
->
FindVar
(
output
);
Tensor
*
tensor_child
=
v
->
GetMutable
<
Tensor
>
();
ScatterUpdate
<
float
>
(
dev_ctx
.
GetPlace
(),
tensor_child
,
index_tensors
[
i
],
tensor_parent
);
}
}
}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
;
private:
// sub_net_op_[0]: subnet_t
...
...
@@ -216,17 +73,5 @@ class CondOp : public OperatorBase {
mutable
std
::
vector
<
std
::
vector
<
int
>>
index_
;
};
/*
class CondGradientOp final : public OperatorBase {
public:
void Init() override;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const
override;
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override;
};*/
}
// namespace operators
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
c7db6e8d
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cond_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
...
...
python/paddle/v2/framework/op.py
浏览文件 @
c7db6e8d
...
...
@@ -217,7 +217,7 @@ class __RecurrentOp__(object):
class
__CondOp__
(
object
):
__proto__
=
None
type
=
'cond_op'
type
=
"cond"
def
__init__
(
self
):
# cache recurrent_op's proto
...
...
@@ -227,8 +227,8 @@ class __CondOp__(object):
self
.
__proto__
=
op_proto
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
self
.
type
not
in
args
and
'type'
not
in
kwargs
:
kwargs
[
'type'
]
=
self
.
type
if
self
.
type
not
in
args
and
"type"
not
in
kwargs
:
kwargs
[
"type"
]
=
self
.
type
# create proto
create_method
=
OpDescCreationMethod
(
self
.
__proto__
)
proto
=
create_method
(
*
args
,
**
kwargs
)
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
c7db6e8d
...
...
@@ -27,6 +27,7 @@ py_test(test_operator SRCS test_operator.py)
py_test
(
test_gaussian_random_op SRCS test_gaussian_random_op.py
)
py_test
(
test_uniform_random_op SRCS test_uniform_random_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_cond_op SRCS test_cond_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_gradient_checker SRCS test_gradient_checker.py
)
py_test
(
test_lookup_table SRCS test_lookup_table.py
)
...
...
python/paddle/v2/framework/tests/test_cond_op.py
浏览文件 @
c7db6e8d
...
...
@@ -11,15 +11,15 @@ class PySimpleCond(object):
'''
def
__init__
(
self
):
array
=
[
True
]
*
10
array
=
[
1
]
*
10
for
i
in
range
(
1
,
10
,
2
):
array
[
i
]
=
False
array
[
i
]
=
0
self
.
cond
=
np
.
array
(
array
)
self
.
x
=
np
.
ones
(
shape
=
(
10
,
1
))
def
forward
(
self
):
self
.
index_t
=
np
.
where
(
self
.
cond
)
self
.
index_f
=
np
.
where
(
self
.
cond
==
False
)
self
.
index_t
=
np
.
where
(
self
.
cond
==
1
)
self
.
index_f
=
np
.
where
(
self
.
cond
==
0
)
y_t
=
self
.
x
[
self
.
index_t
]
y_f
=
self
.
x
[
self
.
index_f
]
y_t
=
y_t
*
2.
...
...
@@ -36,7 +36,6 @@ class PySimpleCondTest(unittest.TestCase):
def
test_forward
(
self
):
output
=
self
.
condnn
.
forward
()
print
'output'
,
output
def
create_tensor
(
scope
,
name
,
shape
,
np_data
):
...
...
@@ -67,47 +66,50 @@ class TestCondOp(unittest.TestCase):
self
.
create_cond_op
()
self
.
create_sub_net
()
ctx
=
core
.
DeviceContext
.
create
(
core
.
CPUPlace
())
print
'running infer shape'
print
self
.
scope
.
find_var
(
"SubScopes"
)
self
.
condop
.
infer_shape
(
self
.
scope
)
print
'ok 2'
self
.
condop
.
run
(
self
.
scope
,
ctx
)
print
'ok 3'
return
np
.
array
(
self
.
scope
.
find_var
(
"Outs"
).
get_tensor
())
return
np
.
array
(
self
.
scope
.
find_var
(
"Out"
).
get_tensor
())
def
create_global_variables
(
self
):
x_np_data
=
self
.
py_cond
.
x
create_tensor
(
self
.
scope
,
"
x
"
,
[
10
,
1
],
x_np_data
)
cond_np_data
=
self
.
py_cond
.
cond
create_tensor
(
self
.
scope
,
"cond"
,
[
10
,
1
],
x
_np_data
)
create_tensor
(
self
.
scope
,
"
X
"
,
[
10
,
1
],
x_np_data
)
cond_np_data
=
self
.
py_cond
.
cond
.
astype
(
"int32"
)
create_tensor
(
self
.
scope
,
"cond"
,
[
10
,
1
],
cond
_np_data
)
self
.
scope
.
new_var
(
"SubScopes"
)
self
.
scope
.
new_var
(
"IndexTensors"
)
self
.
scope
.
new_var
(
"Out
s
"
)
self
.
scope
.
new_var
(
"Out"
)
def
create_cond_op
(
self
):
self
.
condop
=
CondOp
(
Cond
=
"cond"
,
Xs
=
[
"
x
"
],
Outs
=
[
'Out_final'
],
Xs
=
[
"
X
"
],
Outs
=
[
"Out"
],
SubScopes
=
"SubScopes"
,
IndexTensors
=
"IndexTensors"
)
def
create_sub_net
(
self
):
truenet
=
core
.
Net
.
create
()
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Y
=
'Out'
,
scale
=
2.
)
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Out
=
'Out'
,
scale
=
2.
)
truenet
.
append_op
(
scale_op_t
)
truenet
.
complete_add_op
(
True
)
self
.
condop
.
set_truenet
(
truenet
)
falsenet
=
core
.
Net
.
create
()
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Y
=
'Out'
,
scale
=-
2.
)
scale_op_t
=
Operator
(
"scale"
,
X
=
'X'
,
Out
=
'Out'
,
scale
=-
2.
)
falsenet
.
append_op
(
scale_op_t
)
falsenet
.
complete_add_op
(
True
)
self
.
condop
.
set_falsenet
(
falsenet
)
def
test_forward
(
self
):
print
'test cond op forward'
py_output
=
self
.
forward
()
pd_output
=
self
.
forward
()
py_output
=
self
.
py_cond
.
forward
()
print
'pd_output'
,
pd_output
print
print
'py_output'
,
py_output
self
.
assertEqual
(
pd_output
.
shape
,
py_output
.
shape
)
print
'test passed'
return
0
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录