Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cd28cddb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cd28cddb
编写于
3月 09, 2022
作者:
Z
zyfncg
提交者:
GitHub
3月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PHI] Move set_value kernel to phi (#40195)
* save code * fix bug of set_value * add coverage test
上级
63fb0347
变更
14
展开全部
隐藏空白更改
内联
并排
Showing
14 changed file
with
1701 addition
and
212 deletion
+1701
-212
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+64
-1
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+3
-1
paddle/fluid/imperative/execution_context.h
paddle/fluid/imperative/execution_context.h
+5
-0
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+60
-1
paddle/fluid/operators/set_value_op.cc
paddle/fluid/operators/set_value_op.cc
+0
-7
paddle/fluid/operators/set_value_op.cu
paddle/fluid/operators/set_value_op.cu
+0
-7
paddle/fluid/operators/set_value_op.h
paddle/fluid/operators/set_value_op.h
+0
-195
paddle/phi/core/kernel_utils.h
paddle/phi/core/kernel_utils.h
+1
-0
paddle/phi/kernels/cpu/set_value_kernel.cc
paddle/phi/kernels/cpu/set_value_kernel.cc
+38
-0
paddle/phi/kernels/gpu/set_value_kernel.cu
paddle/phi/kernels/gpu/set_value_kernel.cu
+38
-0
paddle/phi/kernels/impl/set_value_kernel_impl.h
paddle/phi/kernels/impl/set_value_kernel_impl.h
+337
-0
paddle/phi/kernels/set_value_kernel.h
paddle/phi/kernels/set_value_kernel.h
+49
-0
paddle/phi/ops/compat/set_value_sig.cc
paddle/phi/ops/compat/set_value_sig.cc
+736
-0
paddle/phi/tests/ops/test_op_signature.cc
paddle/phi/tests/ops/test_op_signature.cc
+370
-0
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
cd28cddb
...
@@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const {
...
@@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const {
return
var
!=
nullptr
;
return
var
!=
nullptr
;
}
}
bool
ExecutionContext
::
HasInputs
(
const
std
::
string
&
name
)
const
{
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
const
auto
*
input
:
it
->
second
)
{
if
(
input
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
bool
ExecutionContext
::
HasOutput
(
const
std
::
string
&
name
)
const
{
bool
ExecutionContext
::
HasOutput
(
const
std
::
string
&
name
)
const
{
auto
*
var
=
OutputVar
(
name
);
auto
*
var
=
OutputVar
(
name
);
return
var
!=
nullptr
;
return
var
!=
nullptr
;
...
@@ -2189,6 +2203,51 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2189,6 +2203,51 @@ void OperatorWithKernel::BuildPhiKernelContext(
std
::
move
(
experimental
::
MakePhiScalarFromVar
(
*
ins_vector
.
front
())));
std
::
move
(
experimental
::
MakePhiScalarFromVar
(
*
ins_vector
.
front
())));
}
}
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
vector
<
phi
::
Scalar
>
)))
{
auto
&
attr
=
Attrs
().
at
(
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
float
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
double
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
double
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext."
,
attr_names
[
i
]));
}
}
else
{
}
else
{
// TODO(chenweihang): support other attrs later
// TODO(chenweihang): support other attrs later
auto
&
attr
=
Attrs
().
at
(
attr_names
[
i
]);
auto
&
attr
=
Attrs
().
at
(
attr_names
[
i
]);
...
@@ -2212,7 +2271,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
...
@@ -2212,7 +2271,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
else
if
(
attr_defs
[
i
].
type_index
==
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
if
(
std
::
type_index
(
attr
.
type
())
==
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
// Emplace Back Attr according to the type of Phi_Kernel args.
// Emplace Back Attr according to the type of Phi_Kernel args.
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
...
...
paddle/fluid/framework/operator.h
浏览文件 @
cd28cddb
...
@@ -295,6 +295,8 @@ class ExecutionContext {
...
@@ -295,6 +295,8 @@ class ExecutionContext {
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
;
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
;
virtual
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
virtual
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
...
@@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
...
@@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
:
ctx_
(
ctx
)
{}
:
ctx_
(
ctx
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
return
ctx_
.
HasInput
(
name
);
return
ctx_
.
HasInput
s
(
name
);
}
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
...
...
paddle/fluid/imperative/execution_context.h
浏览文件 @
cd28cddb
...
@@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
...
@@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
return
(
it
!=
var_map_in_
.
end
()
&&
it
->
second
.
size
()
>
0
);
return
(
it
!=
var_map_in_
.
end
()
&&
it
->
second
.
size
()
>
0
);
}
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
var_map_in_
.
find
(
name
);
return
(
it
!=
var_map_in_
.
end
()
&&
it
->
second
.
size
()
>
0
);
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
var_map_out_
.
find
(
name
);
auto
it
=
var_map_out_
.
find
(
name
);
return
(
it
!=
var_map_out_
.
end
()
&&
it
->
second
.
size
()
>
0
);
return
(
it
!=
var_map_out_
.
end
()
&&
it
->
second
.
size
()
>
0
);
...
...
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
cd28cddb
...
@@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext(
...
@@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext(
}
}
for
(
size_t
i
=
0
;
i
<
attr_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
attr_names
.
size
();
++
i
)
{
VLOG
(
1
)
<<
"############## attr_name: "
<<
i
<<
" : "
<<
attr_names
[
i
];
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
phi
::
ScalarArray
)))
{
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
phi
::
ScalarArray
)))
{
if
(
attrs
.
find
(
attr_names
[
i
])
!=
if
(
attrs
.
find
(
attr_names
[
i
])
!=
attrs
.
end
())
{
// shape is in the attribute
attrs
.
end
())
{
// shape is in the attribute
...
@@ -409,6 +410,60 @@ void BuildDygraphPhiKernelContext(
...
@@ -409,6 +410,60 @@ void BuildDygraphPhiKernelContext(
experimental
::
MakePhiScalarFromVar
(
ins_vector
[
0
]
->
Var
())));
experimental
::
MakePhiScalarFromVar
(
ins_vector
[
0
]
->
Var
())));
}
}
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
vector
<
phi
::
Scalar
>
)))
{
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int32_t
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int32_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
float
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
double
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
double
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
bool
>
)))
{
const
auto
&
vec
=
BOOST_GET_CONST
(
std
::
vector
<
bool
>
,
attr
);
std
::
vector
<
phi
::
Scalar
>
scalar_list
;
scalar_list
.
reserve
(
vec
.
size
());
for
(
const
auto
&
val
:
vec
)
{
scalar_list
.
emplace_back
(
val
);
}
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
scalar_list
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext."
,
attr_names
[
i
]));
}
}
else
{
}
else
{
// TODO(chenweihang): support other attrs later
// TODO(chenweihang): support other attrs later
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
auto
&
attr
=
GetAttr
(
attrs
,
default_attrs
,
attr_names
[
i
]);
...
@@ -432,7 +487,11 @@ void BuildDygraphPhiKernelContext(
...
@@ -432,7 +487,11 @@ void BuildDygraphPhiKernelContext(
}
else
if
(
attr_defs
[
i
].
type_index
==
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
if
(
std
::
type_index
(
attr
.
type
())
==
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
std
::
type_index
(
typeid
(
std
::
vector
<
int64_t
>
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
attr
));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
std
::
vector
<
int
>
)))
{
// Emplace Back Attr according to the type of Phi_Kernel args.
// Emplace Back Attr according to the type of Phi_Kernel args.
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
auto
&
vector_int_attr
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
attr
);
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
const
std
::
vector
<
int64_t
>
vector_int64_attr
(
vector_int_attr
.
begin
(),
...
...
paddle/fluid/operators/set_value_op.cc
浏览文件 @
cd28cddb
...
@@ -241,13 +241,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
...
@@ -241,13 +241,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
ops
::
SetValueGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
SetValueGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
SetValueOpInplaceInferer
);
ops
::
SetValueOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
set_value
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SetValueKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
,
ops
::
SetValueKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
SetValueKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
SetValueKernel
<
plat
::
CPUDeviceContext
,
bool
>
);
REGISTER_OPERATOR
(
set_value_grad
,
ops
::
SetValueGrad
);
REGISTER_OPERATOR
(
set_value_grad
,
ops
::
SetValueGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
...
...
paddle/fluid/operators/set_value_op.cu
浏览文件 @
cd28cddb
...
@@ -16,13 +16,6 @@
...
@@ -16,13 +16,6 @@
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
set_value
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SetValueKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
set_value_grad
,
set_value_grad
,
ops
::
SetValueGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SetValueGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
...
...
paddle/fluid/operators/set_value_op.h
浏览文件 @
cd28cddb
...
@@ -121,201 +121,6 @@ inline void CheckIsDimsMatch(const framework::DDim first,
...
@@ -121,201 +121,6 @@ inline void CheckIsDimsMatch(const framework::DDim first,
"of target shape: %d, but now shape is %d."
,
"of target shape: %d, but now shape is %d."
,
second
.
to_str
(),
first
.
to_str
()));
second
.
to_str
(),
first
.
to_str
()));
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
SetValueKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
int
rank
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
dims
().
size
();
// TODO(liym27): A more elegent code to do this. C++ has to make template
// integer as constant, but we had better have alternative writing in the
// future.
switch
(
rank
)
{
case
1
:
SetValueCompute
<
1
>
(
ctx
);
break
;
case
2
:
SetValueCompute
<
2
>
(
ctx
);
break
;
case
3
:
SetValueCompute
<
3
>
(
ctx
);
break
;
case
4
:
SetValueCompute
<
4
>
(
ctx
);
break
;
case
5
:
SetValueCompute
<
5
>
(
ctx
);
break
;
case
6
:
SetValueCompute
<
6
>
(
ctx
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The rank of input should be less than 7, but received %d."
,
rank
));
}
}
private:
template
<
size_t
D
>
void
SetValueCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
auto
*
value_tensor
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"ValueTensor"
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
starts_tensor_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"StartsTensorList"
);
auto
ends_tensor_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"EndsTensorList"
);
auto
steps_tensor_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"StepsTensorList"
);
auto
axes
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"axes"
);
auto
starts
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"starts"
);
auto
ends
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"ends"
);
auto
steps
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"steps"
);
auto
shape
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
auto
decrease_axes
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"decrease_axes"
);
auto
none_axes
=
ctx
.
Attr
<
std
::
vector
<
int64_t
>>
(
"none_axes"
);
if
(
!
starts_tensor_list
.
empty
())
{
starts
=
GetDataFromTensorList
<
int64_t
>
(
starts_tensor_list
);
}
if
(
!
ends_tensor_list
.
empty
())
{
ends
=
GetDataFromTensorList
<
int64_t
>
(
ends_tensor_list
);
}
if
(
!
steps_tensor_list
.
empty
())
{
steps
=
GetDataFromTensorList
<
int64_t
>
(
steps_tensor_list
);
}
auto
in_dims
=
in
->
dims
();
CheckAndUpdateSliceAttrs
(
in_dims
,
axes
,
&
starts
,
&
ends
,
&
steps
);
auto
slice_dims
=
GetSliceDims
(
in_dims
,
axes
,
starts
,
ends
,
&
steps
);
auto
decrease_slice_dims
=
GetDecreasedDims
(
slice_dims
,
decrease_axes
);
auto
slice_dims_for_assign
=
decrease_slice_dims
;
if
(
!
none_axes
.
empty
())
{
std
::
vector
<
int64_t
>
slice_dims_with_none
;
size_t
none_axes_cur
=
0
,
decrease_axes_cur
=
0
;
for
(
int
i
=
0
;
i
<
slice_dims
.
size
();
++
i
)
{
while
(
none_axes_cur
<
none_axes
.
size
()
&&
none_axes
[
none_axes_cur
]
<=
i
)
{
slice_dims_with_none
.
push_back
(
1
);
none_axes_cur
++
;
}
if
(
decrease_axes_cur
<
decrease_axes
.
size
()
&&
decrease_axes
[
decrease_axes_cur
]
==
i
)
{
decrease_axes_cur
++
;
}
else
{
slice_dims_with_none
.
push_back
(
slice_dims
[
i
]);
}
}
while
(
none_axes_cur
<
none_axes
.
size
())
{
slice_dims_with_none
.
push_back
(
1
);
none_axes_cur
++
;
}
slice_dims_for_assign
=
phi
::
make_ddim
(
slice_dims_with_none
);
}
auto
place
=
ctx
.
GetPlace
();
auto
&
eigen_place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
// Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version.
// - Q: Why don't call ShareDataWith to speed up?
// - A: Because it's not supported to ShareDataWith on OP's input and output
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
// - Q: Why don't delete Input, after all, the input and output are the same
// Tensor at program level?
// - A: If deleting Input, the graph will be complex, such as there will
// be two ops points to the output in graph: op1 -> output <- set_value.
// In this case, we have to find a way to handle the running order of
// set_value is what we want.
paddle
::
framework
::
TensorCopy
(
*
in
,
place
,
out
);
Tensor
slice_tensor
(
in
->
dtype
()),
pad_tensor
(
in
->
dtype
());
slice_tensor
.
mutable_data
<
T
>
(
slice_dims
,
place
);
pad_tensor
.
mutable_data
<
T
>
(
in_dims
,
place
);
auto
pad_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
pad_tensor
,
in_dims
);
auto
out_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
*
out
);
auto
slice_e
=
framework
::
EigenTensor
<
T
,
D
>::
From
(
slice_tensor
,
slice_dims
);
// Step 1: Set the value of out at `_index` to zero
slice_e
.
device
(
eigen_place
)
=
slice_e
.
constant
(
T
(
0
));
auto
starts_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
D
>
();
auto
ends_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
D
>
();
auto
strides_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
D
>
();
for
(
size_t
i
=
0
;
i
<
D
;
++
i
)
{
starts_indices
[
i
]
=
0
;
ends_indices
[
i
]
=
slice_dims
[
i
];
strides_indices
[
i
]
=
1
;
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
axis_index
=
axes
[
i
];
starts_indices
[
axis_index
]
=
starts
[
i
];
ends_indices
[
axis_index
]
=
ends
[
i
];
strides_indices
[
axis_index
]
=
steps
[
i
];
if
(
starts
[
i
]
==
ends
[
i
])
{
// slice is empty, data will not be changed
return
;
}
}
out_e
.
stridedSlice
(
starts_indices
,
ends_indices
,
strides_indices
)
.
device
(
eigen_place
)
=
slice_e
;
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When do broadcasting on slice_tensor and value_tensor, the shape of
// slice_tensor should be decreased dims.
// e.g.
// x[:,0] = value_tensor
// x's shape = [3, 4], value_tensor's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
// shape is [3, 3], which cross the border;
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor
.
Resize
(
slice_dims_for_assign
);
if
(
value_tensor
!=
nullptr
)
{
CheckIsDimsMatch
(
slice_dims_for_assign
,
value_tensor
->
dims
());
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
slice_tensor
,
value_tensor
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_tensor
);
}
else
{
Tensor
value_t
(
in
->
dtype
());
auto
value_dims
=
phi
::
make_ddim
(
shape
);
CheckIsDimsMatch
(
slice_dims_for_assign
,
value_dims
);
value_t
.
mutable_data
<
T
>
(
value_dims
,
place
);
auto
value_name
=
GetValueName
(
framework
::
TransToProtoVarType
(
in
->
dtype
()));
CopyVecotorToTensor
<
T
>
(
value_name
.
c_str
(),
&
value_t
,
ctx
);
value_t
.
Resize
(
value_dims
);
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
slice_tensor
,
&
value_t
,
-
1
,
SubFunctor
<
T
>
(),
&
slice_tensor
);
}
slice_tensor
.
Resize
(
slice_dims
);
// - Step 2.2 Pad slice tensor with 0
pad_e
.
device
(
eigen_place
)
=
pad_e
.
constant
(
T
(
0
));
pad_e
.
stridedSlice
(
starts_indices
,
ends_indices
,
strides_indices
)
.
device
(
eigen_place
)
=
slice_e
;
// Step 3: Set out tensor with value_tensor
out_e
.
device
(
eigen_place
)
=
out_e
-
pad_e
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
SetValueGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SetValueGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
paddle/phi/core/kernel_utils.h
浏览文件 @
cd28cddb
...
@@ -252,6 +252,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
...
@@ -252,6 +252,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
float
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
float
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
double
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
double
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
std
::
string
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
std
::
string
>&
);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE
(
const
std
::
vector
<
Scalar
>&
);
/* Output Helpers */
/* Output Helpers */
...
...
paddle/phi/kernels/cpu/set_value_kernel.cc
0 → 100644
浏览文件 @
cd28cddb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
PD_REGISTER_KERNEL
(
set_value
,
CPU
,
ALL_LAYOUT
,
phi
::
SetValueKernel
,
float
,
double
,
int
,
int64_t
,
bool
)
{}
PD_REGISTER_KERNEL
(
set_value_with_tensor
,
CPU
,
ALL_LAYOUT
,
phi
::
SetTensorValueKernel
,
float
,
double
,
int
,
int64_t
,
bool
)
{}
paddle/phi/kernels/gpu/set_value_kernel.cu
0 → 100644
浏览文件 @
cd28cddb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
PD_REGISTER_KERNEL
(
set_value
,
GPU
,
ALL_LAYOUT
,
phi
::
SetValueKernel
,
float
,
double
,
int
,
int64_t
,
bool
)
{}
PD_REGISTER_KERNEL
(
set_value_with_tensor
,
GPU
,
ALL_LAYOUT
,
phi
::
SetTensorValueKernel
,
float
,
double
,
int
,
int64_t
,
bool
)
{}
paddle/phi/kernels/impl/set_value_kernel_impl.h
0 → 100644
浏览文件 @
cd28cddb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/slice_utils.h"
namespace
phi
{
// check whether the tensor with dimension of second can assign to the
// tensor with dimension of first
inline
void
CheckIsDimsMatch
(
const
DDim
&
first
,
const
DDim
&
second
)
{
int
ignore_axis1
=
0
,
ignore_axis2
=
0
;
for
(;
ignore_axis1
<
first
.
size
();
++
ignore_axis1
)
{
if
(
first
[
ignore_axis1
]
!=
1
)
{
break
;
}
}
for
(;
ignore_axis2
<
second
.
size
();
++
ignore_axis2
)
{
if
(
second
[
ignore_axis2
]
!=
1
)
{
break
;
}
}
if
(
second
.
size
()
==
ignore_axis2
)
{
// second tensor has only one value
return
;
}
if
(
first
.
size
()
-
ignore_axis1
>=
second
.
size
()
-
ignore_axis2
)
{
auto
idx1
=
first
.
size
()
-
1
;
auto
idx2
=
second
.
size
()
-
1
;
bool
is_match
=
true
;
for
(;
idx2
>=
ignore_axis2
;
idx2
--
)
{
if
(
first
[
idx1
--
]
!=
second
[
idx2
]
&&
second
[
idx2
]
!=
1
)
{
is_match
=
false
;
break
;
}
}
if
(
is_match
)
{
return
;
}
}
PADDLE_THROW
(
errors
::
InvalidArgument
(
"The shape of tensor assigned value must match the shape "
"of target shape: %d, but now shape is %d."
,
second
.
to_str
(),
first
.
to_str
()));
}
template
<
typename
T
,
typename
Context
,
size_t
RANK
>
void
SetValueImpl
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
in
,
const
DenseTensor
&
value
,
const
ScalarArray
&
starts
,
const
ScalarArray
&
ends
,
const
ScalarArray
&
steps
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
decrease_axes
,
const
std
::
vector
<
int64_t
>&
none_axes
,
DenseTensor
*
out
)
{
auto
in_dims
=
in
.
dims
();
std
::
vector
<
int64_t
>
starts_local
=
starts
.
GetData
();
std
::
vector
<
int64_t
>
ends_local
=
ends
.
GetData
();
std
::
vector
<
int64_t
>
steps_local
=
steps
.
GetData
();
paddle
::
operators
::
CheckAndUpdateSliceAttrs
(
in_dims
,
axes
,
&
starts_local
,
&
ends_local
,
&
steps_local
);
auto
slice_dims
=
paddle
::
operators
::
GetSliceDims
(
in_dims
,
axes
,
starts_local
,
ends_local
,
&
steps_local
);
auto
decrease_slice_dims
=
paddle
::
operators
::
GetDecreasedDims
(
slice_dims
,
decrease_axes
);
auto
slice_dims_for_assign
=
decrease_slice_dims
;
if
(
!
none_axes
.
empty
())
{
std
::
vector
<
int64_t
>
slice_dims_with_none
;
size_t
none_axes_cur
=
0
,
decrease_axes_cur
=
0
;
for
(
int
i
=
0
;
i
<
slice_dims
.
size
();
++
i
)
{
while
(
none_axes_cur
<
none_axes
.
size
()
&&
none_axes
[
none_axes_cur
]
<=
i
)
{
slice_dims_with_none
.
push_back
(
1
);
none_axes_cur
++
;
}
if
(
decrease_axes_cur
<
decrease_axes
.
size
()
&&
decrease_axes
[
decrease_axes_cur
]
==
i
)
{
decrease_axes_cur
++
;
}
else
{
slice_dims_with_none
.
push_back
(
slice_dims
[
i
]);
}
}
while
(
none_axes_cur
<
none_axes
.
size
())
{
slice_dims_with_none
.
push_back
(
1
);
none_axes_cur
++
;
}
slice_dims_for_assign
=
phi
::
make_ddim
(
slice_dims_with_none
);
}
auto
place
=
dev_ctx
.
GetPlace
();
auto
&
eigen_place
=
*
dev_ctx
.
eigen_device
();
// Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version.
// - Q: Why don't call ShareDataWith to speed up?
// - A: Because it's not supported to ShareDataWith on OP's input and output
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
// - Q: Why don't delete Input, after all, the input and output are the same
// Tensor at program level?
// - A: If deleting Input, the graph will be complex, such as there will
// be two ops points to the output in graph: op1 -> output <- set_value.
// In this case, we have to find a way to handle the running order of
// set_value is what we want.
Copy
(
dev_ctx
,
in
,
place
,
false
,
out
);
DenseTensor
slice_tensor
=
Empty
<
T
>
(
dev_ctx
,
ScalarArray
{
slice_dims
.
Get
(),
slice_dims
.
size
()});
DenseTensor
pad_tensor
=
Empty
<
T
>
(
dev_ctx
,
ScalarArray
{
in_dims
.
Get
(),
in_dims
.
size
()});
auto
pad_e
=
EigenTensor
<
T
,
RANK
>::
From
(
pad_tensor
,
in_dims
);
auto
out_e
=
EigenTensor
<
T
,
RANK
>::
From
(
*
out
);
auto
slice_e
=
EigenTensor
<
T
,
RANK
>::
From
(
slice_tensor
,
slice_dims
);
// Step 1: Set the value of out at `_index` to zero
slice_e
.
device
(
eigen_place
)
=
slice_e
.
constant
(
T
(
0
));
auto
starts_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
RANK
>
();
auto
ends_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
RANK
>
();
auto
strides_indices
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
RANK
>
();
for
(
size_t
i
=
0
;
i
<
RANK
;
++
i
)
{
starts_indices
[
i
]
=
0
;
ends_indices
[
i
]
=
slice_dims
[
i
];
strides_indices
[
i
]
=
1
;
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
int
axis_index
=
axes
[
i
];
starts_indices
[
axis_index
]
=
starts_local
[
i
];
ends_indices
[
axis_index
]
=
ends_local
[
i
];
strides_indices
[
axis_index
]
=
steps_local
[
i
];
if
(
starts_local
[
i
]
==
ends_local
[
i
])
{
// slice is empty, data will not be changed
return
;
}
}
out_e
.
stridedSlice
(
starts_indices
,
ends_indices
,
strides_indices
)
.
device
(
eigen_place
)
=
slice_e
;
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When do broadcasting on slice_tensor and value, the shape of
// slice_tensor should be decreased dims.
// e.g.
// x[:,0] = value
// x's shape = [3, 4], value's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
// shape is [3, 3], which cross the border;
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor
.
Resize
(
slice_dims_for_assign
);
CheckIsDimsMatch
(
slice_dims_for_assign
,
value
.
dims
());
// ElementwiseComputeEx can do broadcasting
funcs
::
ElementwiseCompute
<
funcs
::
SubtractFunctor
<
T
>
,
T
>
(
dev_ctx
,
slice_tensor
,
value
,
-
1
,
funcs
::
SubtractFunctor
<
T
>
(),
&
slice_tensor
);
slice_tensor
.
Resize
(
slice_dims
);
// - Step 2.2 Pad slice tensor with 0
pad_e
.
device
(
eigen_place
)
=
pad_e
.
constant
(
T
(
0
));
pad_e
.
stridedSlice
(
starts_indices
,
ends_indices
,
strides_indices
)
.
device
(
eigen_place
)
=
slice_e
;
// Step 3: Set out tensor with value
out_e
.
device
(
eigen_place
)
=
out_e
-
pad_e
;
}
template
<
typename
T
,
typename
Context
>
void
SetTensorValueKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
value
,
const
ScalarArray
&
starts
,
const
ScalarArray
&
ends
,
const
ScalarArray
&
steps
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
decrease_axes
,
const
std
::
vector
<
int64_t
>&
none_axes
,
DenseTensor
*
out
)
{
const
int
rank
=
x
.
dims
().
size
();
switch
(
rank
)
{
case
1
:
SetValueImpl
<
T
,
Context
,
1
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
case
2
:
SetValueImpl
<
T
,
Context
,
2
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
case
3
:
SetValueImpl
<
T
,
Context
,
3
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
case
4
:
SetValueImpl
<
T
,
Context
,
4
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
case
5
:
SetValueImpl
<
T
,
Context
,
5
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
case
6
:
SetValueImpl
<
T
,
Context
,
6
>
(
dev_ctx
,
x
,
value
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
break
;
default:
PADDLE_THROW
(
errors
::
InvalidArgument
(
"The rank of input should be less than 7, but received %d."
,
rank
));
}
}
template
<
typename
T
,
typename
Context
>
void
SetValueKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
ScalarArray
&
starts
,
const
ScalarArray
&
ends
,
const
ScalarArray
&
steps
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
decrease_axes
,
const
std
::
vector
<
int64_t
>&
none_axes
,
const
std
::
vector
<
int64_t
>&
shape
,
const
std
::
vector
<
Scalar
>&
values
,
DenseTensor
*
out
)
{
std
::
vector
<
T
>
assgin_values
;
assgin_values
.
reserve
(
values
.
size
());
for
(
const
auto
&
val
:
values
)
{
assgin_values
.
push_back
(
val
.
to
<
T
>
());
}
DenseTensor
value_tensor
=
Empty
<
T
>
(
dev_ctx
,
shape
);
paddle
::
framework
::
TensorFromVector
(
assgin_values
,
dev_ctx
,
&
value_tensor
);
value_tensor
.
Resize
(
phi
::
make_ddim
(
shape
));
SetTensorValueKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
value_tensor
,
starts
,
ends
,
steps
,
axes
,
decrease_axes
,
none_axes
,
out
);
}
}
// namespace phi
paddle/phi/kernels/set_value_kernel.h
0 → 100644
浏览文件 @
cd28cddb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
SetTensorValueKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
value
,
const
ScalarArray
&
starts
,
const
ScalarArray
&
ends
,
const
ScalarArray
&
steps
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
decrease_axes
,
const
std
::
vector
<
int64_t
>&
none_axes
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
SetValueKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
ScalarArray
&
starts
,
const
ScalarArray
&
ends
,
const
ScalarArray
&
steps
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
decrease_axes
,
const
std
::
vector
<
int64_t
>&
none_axes
,
const
std
::
vector
<
int64_t
>&
shape
,
const
std
::
vector
<
Scalar
>&
values
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/set_value_sig.cc
0 → 100644
浏览文件 @
cd28cddb
此差异已折叠。
点击以展开。
paddle/phi/tests/ops/test_op_signature.cc
浏览文件 @
cd28cddb
...
@@ -114,5 +114,375 @@ TEST(ARG_MAP, fill_constant) {
...
@@ -114,5 +114,375 @@ TEST(ARG_MAP, fill_constant) {
ASSERT_EQ
(
signature9
.
name
,
"full_sr"
);
ASSERT_EQ
(
signature9
.
name
,
"full_sr"
);
}
}
TEST
(
ARG_MAP
,
set_value
)
{
TestArgumentMappingContext
arg_case
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp32_values"
,
paddle
::
any
{
std
::
vector
<
float
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case1
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case1
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case2
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case2
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case3
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case3
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case4
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case4
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case5
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case5
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case6
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case6
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case7
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case7
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case8
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case8
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case9
(
{
"Input"
,
"StartsTensorList"
,
"EndsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case9
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case10
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case10
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case11
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case11
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case12
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case12
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case13
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case13
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case14
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case14
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case15
(
{
"Input"
,
"StartsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case15
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case16
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp32_values"
,
paddle
::
any
{
std
::
vector
<
float
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case16
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case17
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case17
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case18
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case18
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case19
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case19
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case20
(
{
"Input"
,
"StartsTensorList"
,
"StepsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case20
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case21
(
{
"Input"
,
"EndsTensorList"
,
"StepsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case21
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case22
(
{
"Input"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case22
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case23
(
{
"Input"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case23
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case24
(
{
"Input"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case24
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case25
(
{
"Input"
,
"EndsTensorList"
,
"StepsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case25
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case26
(
{
"Input"
,
"EndsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case26
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case27
(
{
"Input"
,
"EndsTensorList"
},
{},
{{
"fp32_values"
,
paddle
::
any
{
std
::
vector
<
float
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case27
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case28
(
{
"Input"
,
"EndsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case28
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case29
(
{
"Input"
,
"EndsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case29
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case30
(
{
"Input"
,
"EndsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case30
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case31
(
{
"Input"
,
"EndsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case31
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case32
(
{
"Input"
,
"StepsTensorList"
,
"ValueTensor"
},
{},
{},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case32
).
name
,
"set_value_with_tensor"
);
TestArgumentMappingContext
arg_case33
(
{
"Input"
,
"StepsTensorList"
},
{},
{{
"fp32_values"
,
paddle
::
any
{
std
::
vector
<
float
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case33
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case34
(
{
"Input"
,
"StepsTensorList"
},
{},
{{
"fp64_values"
,
paddle
::
any
{
std
::
vector
<
double
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case34
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case35
(
{
"Input"
,
"StepsTensorList"
},
{},
{{
"int32_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case35
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case36
(
{
"Input"
,
"StepsTensorList"
},
{},
{{
"int64_values"
,
paddle
::
any
{
std
::
vector
<
int64_t
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case36
).
name
,
"set_value"
);
TestArgumentMappingContext
arg_case37
(
{
"Input"
,
"StepsTensorList"
},
{},
{{
"bool_values"
,
paddle
::
any
{
std
::
vector
<
int
>
{
1
}}}},
{
"Out"
},
{});
ASSERT_EQ
(
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"set_value"
)(
arg_case37
).
name
,
"set_value"
);
}
}
// namespace tests
}
// namespace tests
}
// namespace phi
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录