Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d4b67e16
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4b67e16
编写于
5月 10, 2019
作者:
Z
zhoukunsheng
提交者:
Tao Luo
5月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Where Op(#16793)
上级
1bfff020
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
363 addition
and
0 deletion
+363
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/where_op.cc
paddle/fluid/operators/where_op.cc
+58
-0
paddle/fluid/operators/where_op.cu
paddle/fluid/operators/where_op.cu
+81
-0
paddle/fluid/operators/where_op.h
paddle/fluid/operators/where_op.h
+95
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+36
-0
python/paddle/fluid/tests/unittests/test_where.py
python/paddle/fluid/tests/unittests/test_where.py
+92
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
d4b67e16
...
...
@@ -234,6 +234,7 @@ paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '132b6e74ff642a392bd6b14c10aedc65'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cf12066a3139026119f97f9d4381a1bd'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
...
...
paddle/fluid/operators/where_op.cc
0 → 100644
浏览文件 @
d4b67e16
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/where_op.h"
namespace
paddle
{
namespace
operators
{
class
WhereOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Condition"
),
"Input(Condition) of WhereOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
GetInputDim
(
"Condition"
).
size
()
>=
1
,
"Input(Condition) should have number of dimension at least 1"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(OUt) of WhereOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
{
-
1
,
ctx
->
GetInputDim
(
"Condition"
).
size
()});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
output_type
=
framework
::
proto
::
VarType
::
INT64
;
return
framework
::
OpKernelType
(
output_type
,
ctx
.
device_context
());
}
};
class
WhereOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Condition"
,
"A bool tensor whose rank is at least 1"
);
AddOutput
(
"Out"
,
"An int64 tensor of rank 2"
);
AddComment
(
R"DOC(
Return a int64 tensor with rank 2, specifying the coordinate of true element in `Condition`.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
where
,
ops
::
WhereOp
,
ops
::
WhereOpMaker
);
REGISTER_OP_CPU_KERNEL
(
where
,
ops
::
CPUWhereKernel
<
int64_t
>
);
paddle/fluid/operators/where_op.cu
0 → 100644
浏览文件 @
d4b67e16
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
CUDADeviceContext
=
paddle
::
platform
::
CUDADeviceContext
;
template
<
typename
T
>
class
CUDAWhereKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
// TODO(zhoukunsheng): Should optimize to ensure GPU is faster than CPU.
framework
::
Tensor
cond_cpu
;
framework
::
TensorCopy
(
*
condition
,
platform
::
CPUPlace
(),
&
cond_cpu
);
const
bool
*
cond_data
=
cond_cpu
.
data
<
bool
>
();
int64_t
numel
=
cond_cpu
.
numel
();
auto
dims
=
cond_cpu
.
dims
();
int
rank
=
dims
.
size
();
thrust
::
host_vector
<
int
>
h_true_index
;
for
(
int64_t
i
=
0
;
i
<
numel
;
i
++
)
{
if
(
cond_data
[
i
])
{
h_true_index
.
push_back
(
i
);
}
}
thrust
::
device_vector
<
int
>
d_true_index
=
h_true_index
;
int
*
ptr_true_index
=
thrust
::
raw_pointer_cast
(
d_true_index
.
data
());
size_t
true_num
=
h_true_index
.
size
();
out
->
Resize
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
true_num
),
rank
}));
auto
out_ptr
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
true_num
==
0
)
{
return
;
}
thrust
::
host_vector
<
int
>
h_stride
(
rank
,
0
);
h_stride
[
rank
-
1
]
=
1
;
for
(
int
i
=
rank
-
2
;
i
>=
0
;
i
--
)
{
h_stride
[
i
]
=
h_stride
[
i
+
1
]
*
dims
[
i
+
1
];
}
thrust
::
device_vector
<
int
>
d_stride
=
h_stride
;
int
*
ptr_stride
=
thrust
::
raw_pointer_cast
(
d_stride
.
data
());
auto
&
dev_ctx
=
context
.
template
device_context
<
CUDADeviceContext
>();
WhereFunctor
<
int
*>
functor
(
ptr_true_index
,
true_num
,
ptr_stride
,
rank
,
out_ptr
);
platform
::
ForRange
<
CUDADeviceContext
>
for_range
(
dev_ctx
,
true_num
);
for_range
(
functor
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
where
,
ops
::
CUDAWhereKernel
<
int64_t
>
);
paddle/fluid/operators/where_op.h
0 → 100644
浏览文件 @
d4b67e16
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
WhereFunctor
{
WhereFunctor
(
const
T
&
true_index
,
int
true_num
,
const
T
&
stride
,
int
rank
,
int64_t
*
out
)
:
true_index_
(
true_index
),
true_num_
(
true_num
),
stride_
(
stride
),
rank_
(
rank
),
out_ptr_
(
out
)
{}
HOSTDEVICE
void
operator
()(
size_t
idx
)
const
{
int
index
=
true_index_
[
idx
];
for
(
int
j
=
0
;
j
<
rank_
;
j
++
)
{
out_ptr_
[
idx
*
rank_
+
j
]
=
index
/
stride_
[
j
];
index
-=
out_ptr_
[
idx
*
rank_
+
j
]
*
stride_
[
j
];
}
}
const
T
true_index_
;
int
true_num_
;
const
T
stride_
;
int
rank_
;
int64_t
*
out_ptr_
;
};
using
CPUDeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
template
<
typename
T
>
class
CPUWhereKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
bool
*
cond_data
=
condition
->
data
<
bool
>
();
auto
numel
=
condition
->
numel
();
auto
dims
=
condition
->
dims
();
const
int
rank
=
dims
.
size
();
std
::
vector
<
int
>
true_index
;
for
(
auto
i
=
0
;
i
<
numel
;
i
++
)
{
if
(
cond_data
[
i
])
{
true_index
.
push_back
(
i
);
}
}
auto
true_num
=
true_index
.
size
();
out
->
Resize
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
true_num
),
rank
}));
auto
out_ptr
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
true_num
==
0
)
{
return
;
}
std
::
vector
<
int
>
stride
(
rank
);
stride
[
rank
-
1
]
=
1
;
for
(
int
i
=
rank
-
2
;
i
>=
0
;
i
--
)
{
stride
[
i
]
=
stride
[
i
+
1
]
*
dims
[
i
+
1
];
}
auto
&
dev_ctx
=
context
.
template
device_context
<
CPUDeviceContext
>();
WhereFunctor
<
int
*>
functor
(
true_index
.
data
(),
true_num
,
stride
.
data
(),
rank
,
out_ptr
);
platform
::
ForRange
<
CPUDeviceContext
>
for_range
(
dev_ctx
,
true_num
);
for_range
(
functor
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
d4b67e16
...
...
@@ -200,6 +200,7 @@ __all__ = [
'pixel_shuffle'
,
'fsp_matrix'
,
'continuous_value_model'
,
'where'
,
]
kIgnoreIndex
=
-
100
...
...
@@ -11341,3 +11342,38 @@ def continuous_value_model(input, cvm, use_cvm=True):
outputs
=
{
'Y'
:
[
out
]},
attrs
=
{
"use_cvm"
:
use_cvm
})
return
out
def
where
(
condition
):
"""
Return an int64 tensor with rank 2, specifying the coordinate of true element in `condition`.
Output's first dimension is the number of true element, second dimension is rank(number of dimension) of `condition`.
If there is zero true element, then an empty tensor will be generated.
Args:
condition(Variable): A bool tensor with rank at least 1.
Returns:
Variable: The tensor variable storing a 2-D tensor.
Examples:
.. code-block:: python
# condition is a tensor [True, False, True]
out = fluid.layers.where(condition) # [[0], [2]]
# condition is a tensor [[True, False], [False, True]]
out = fluid.layers.where(condition) # [[0, 0], [1, 1]]
# condition is a tensor [False, False, False]
out = fluid.layers.where(condition) # [[]]
"""
helper
=
LayerHelper
(
"where"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
core
.
VarDesc
.
VarType
.
INT64
)
helper
.
append_op
(
type
=
'where'
,
inputs
=
{
'Condition'
:
condition
},
outputs
=
{
'Out'
:
[
out
]})
return
out
python/paddle/fluid/tests/unittests/test_where.py
0 → 100644
浏览文件 @
d4b67e16
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
class
TestWhereOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"where"
self
.
init_config
()
def
test_check_output
(
self
):
self
.
check_output
()
def
init_config
(
self
):
self
.
inputs
=
{
'Condition'
:
np
.
array
([
True
,
False
,
True
]),
}
self
.
outputs
=
{
'Out'
:
np
.
array
([[
0
],
[
2
]],
dtype
=
'int64'
)}
class
TestAllFalse
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
op_type
=
"where"
self
.
init_config
()
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
condition
=
scope
.
var
(
'Condition'
).
get_tensor
()
condition
.
set
(
self
.
cond_data
,
place
)
out
=
scope
.
var
(
"Out"
).
get_tensor
()
out
.
set
(
np
.
full
(
self
.
shape
,
0
).
astype
(
'int64'
),
place
)
op
=
Operator
(
"where"
,
Condition
=
"Condition"
,
Out
=
"Out"
)
op
.
run
(
scope
,
place
)
out_array
=
np
.
array
(
out
)
self
.
assertTrue
((
out_array
==
self
.
out_data
).
all
())
def
init_config
(
self
):
self
.
cond_data
=
np
.
array
([
False
,
False
,
False
])
self
.
shape
=
(
3
,
1
)
self
.
out_data
=
np
.
array
([],
dtype
=
'int64'
)
def
test_all_false
(
self
):
self
.
check_with_place
(
core
.
CPUPlace
())
if
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
core
.
CUDAPlace
(
0
))
class
TestRank2
(
TestWhereOp
):
def
init_config
(
self
):
self
.
inputs
=
{
'Condition'
:
np
.
array
([[
True
,
False
],
[
False
,
True
]]),
}
self
.
outputs
=
{
'Out'
:
np
.
array
([[
0
,
0
],
[
1
,
1
]],
dtype
=
'int64'
)}
class
TestRank3
(
TestWhereOp
):
def
init_config
(
self
):
self
.
inputs
=
{
'Condition'
:
np
.
array
([[[
True
,
False
],
[
False
,
True
]],
[[
False
,
True
],
[
True
,
False
]],
[[
False
,
False
],
[
False
,
True
]]]),
}
self
.
outputs
=
{
'Out'
:
np
.
array
(
[[
0
,
0
,
0
],
[
0
,
1
,
1
],
[
1
,
0
,
1
],
[
1
,
1
,
0
],
[
2
,
1
,
1
]],
dtype
=
'int64'
)
}
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录