Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
75923a32
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
75923a32
编写于
1月 28, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Update all forward argument maping fns (#39252)
* update forward argument mapping * fix compile failed * fix test failed
上级
9a001c09
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
394 addition
and
128 deletion
+394
-128
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+10
-0
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+10
-2
paddle/fluid/operators/cast_op.cc
paddle/fluid/operators/cast_op.cc
+0
-5
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+0
-9
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+0
-44
paddle/fluid/operators/empty_op.cc
paddle/fluid/operators/empty_op.cc
+0
-14
paddle/fluid/operators/fill_any_like_op.cc
paddle/fluid/operators/fill_any_like_op.cc
+0
-5
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+0
-23
paddle/fluid/operators/flatten_op.cc
paddle/fluid/operators/flatten_op.cc
+0
-12
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+0
-14
paddle/pten/core/compat/arg_map_context.h
paddle/pten/core/compat/arg_map_context.h
+3
-0
paddle/pten/ops/compat/cast_sig.cc
paddle/pten/ops/compat/cast_sig.cc
+25
-0
paddle/pten/ops/compat/concat_sig.cc
paddle/pten/ops/compat/concat_sig.cc
+28
-0
paddle/pten/ops/compat/elementwise_sig.cc
paddle/pten/ops/compat/elementwise_sig.cc
+76
-0
paddle/pten/ops/compat/empty_sig.cc
paddle/pten/ops/compat/empty_sig.cc
+31
-0
paddle/pten/ops/compat/fill_any_like_sig.cc
paddle/pten/ops/compat/fill_any_like_sig.cc
+26
-0
paddle/pten/ops/compat/fill_constant_sig.cc
paddle/pten/ops/compat/fill_constant_sig.cc
+71
-0
paddle/pten/ops/compat/flatten_sig.cc
paddle/pten/ops/compat/flatten_sig.cc
+34
-0
paddle/pten/ops/compat/reduce_sig.cc
paddle/pten/ops/compat/reduce_sig.cc
+49
-0
paddle/pten/ops/compat/reshape_sig.cc
paddle/pten/ops/compat/reshape_sig.cc
+31
-0
未找到文件。
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
75923a32
...
...
@@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return
var_types
[
0
]
==
proto
::
VarType
::
SELECTED_ROWS
;
}
bool
IsDenseTensorOutput
(
const
std
::
string
&
name
)
const
override
{
auto
var_types
=
ctx_
.
GetOutputsVarType
(
name
);
return
var_types
[
0
]
==
proto
::
VarType
::
LOD_TENSOR
;
}
bool
IsSelectedRowsOutput
(
const
std
::
string
&
name
)
const
override
{
auto
var_types
=
ctx_
.
GetOutputsVarType
(
name
);
return
var_types
[
0
]
==
proto
::
VarType
::
SELECTED_ROWS
;
}
private:
const
InferShapeContext
&
ctx_
;
};
...
...
paddle/fluid/framework/operator.h
浏览文件 @
75923a32
...
...
@@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
}
size_t
InputSize
(
const
std
::
string
&
name
)
const
override
{
return
ctx_
.
InputSize
(
name
);
return
ctx_
.
MultiInputVar
(
name
).
size
(
);
}
size_t
OutputSize
(
const
std
::
string
&
name
)
const
override
{
return
ctx_
.
OutputSize
(
name
);
return
ctx_
.
MultiOutputVar
(
name
).
size
(
);
}
bool
IsDenseTensorInput
(
const
std
::
string
&
name
)
const
override
{
...
...
@@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return
ctx_
.
InputVar
(
name
)
->
IsType
<
pten
::
SelectedRows
>
();
}
bool
IsDenseTensorOutput
(
const
std
::
string
&
name
)
const
override
{
return
ctx_
.
OutputVar
(
name
)
->
IsType
<
framework
::
LoDTensor
>
();
}
bool
IsSelectedRowsOutput
(
const
std
::
string
&
name
)
const
override
{
return
ctx_
.
OutputVar
(
name
)
->
IsType
<
pten
::
SelectedRows
>
();
}
private:
const
ExecutionContext
&
ctx_
;
};
...
...
paddle/fluid/operators/cast_op.cc
浏览文件 @
75923a32
...
...
@@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel {
#endif
return
framework
::
OpKernelType
(
tensor
->
type
(),
tensor_place
);
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
KernelSignature
(
"cast"
,
{
"X"
},
{
"out_dtype"
},
{
"Out"
});
}
};
}
// namespace operators
...
...
paddle/fluid/operators/concat_op.cc
浏览文件 @
75923a32
...
...
@@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel {
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasInput
(
"AxisTensor"
))
{
return
framework
::
KernelSignature
(
"concat"
,
{
"X"
},
{
"AxisTensor"
},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"concat"
,
{
"X"
},
{
"axis"
},
{
"Out"
});
}
};
class
ConcatOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
75923a32
...
...
@@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
tensor
.
place
(),
tensor
.
layout
());
}
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
Type
()
==
"elementwise_add"
)
{
if
(
ctx
.
InputVar
(
"X"
)
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
axis
==
-
1
)
{
return
framework
::
KernelSignature
(
"add"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"add_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
}
if
(
Type
()
==
"elementwise_sub"
)
{
if
(
ctx
.
InputVar
(
"X"
)
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
axis
==
-
1
)
{
return
framework
::
KernelSignature
(
"subtract"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"subtract_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
}
if
(
Type
()
==
"elementwise_div"
)
{
if
(
ctx
.
InputVar
(
"X"
)
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
axis
==
-
1
)
{
return
framework
::
KernelSignature
(
"divide"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"divide_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
}
if
(
Type
()
==
"elementwise_mul"
)
{
if
(
ctx
.
InputVar
(
"X"
)
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
axis
==
-
1
)
{
return
framework
::
KernelSignature
(
"multiply"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"multiply_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
}
return
framework
::
KernelSignature
(
"None"
,
{
"X"
},
{},
{
"Out"
});
}
};
class
ElementwiseOpInferVarType
...
...
paddle/fluid/operators/empty_op.cc
浏览文件 @
75923a32
...
...
@@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel {
framework
::
proto
::
VarType
::
Type
(
context
.
Attr
<
int
>
(
"dtype"
)),
context
.
GetPlace
());
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
string
shape
;
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
shape
=
"ShapeTensor"
;
}
else
if
(
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"ShapeTensorList"
).
size
())
{
shape
=
"ShapeTensorList"
;
}
else
{
shape
=
"shape"
;
}
return
framework
::
KernelSignature
(
"empty"
,
{},
{
shape
},
{
"Out"
});
}
};
class
EmptyOpVarTypeInference
:
public
framework
::
VarTypeInference
{
...
...
paddle/fluid/operators/fill_any_like_op.cc
浏览文件 @
75923a32
...
...
@@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
expected_kernel_type
.
place_
,
tensor
.
layout
());
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
KernelSignature
(
"full_like"
,
{},
{
"value"
},
{
"Out"
});
}
};
class
FillAnyLikeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
75923a32
...
...
@@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
return
kt
;
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
string
shape
;
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
shape
=
"ShapeTensor"
;
}
else
if
(
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"ShapeTensorList"
).
size
())
{
shape
=
"ShapeTensorList"
;
}
else
{
shape
=
"shape"
;
}
std
::
string
value
;
if
(
ctx
.
HasInput
(
"ValueTensor"
))
{
value
=
"ValueTensor"
;
}
else
{
const
auto
&
str_value
=
ctx
.
Attr
<
std
::
string
>
(
"str_value"
);
value
=
str_value
.
empty
()
?
"value"
:
"str_value"
;
}
if
(
!
ctx
.
OutputVar
(
"Out"
)
->
IsType
<
pten
::
SelectedRows
>
())
{
return
framework
::
KernelSignature
(
"full"
,
{},
{
shape
,
value
},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"fill_constant.unregistered"
,
{},
{},
{});
}
};
class
FillConstantOpVarTypeInference
:
public
framework
::
VarTypeInference
{
...
...
paddle/fluid/operators/flatten_op.cc
浏览文件 @
75923a32
...
...
@@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
return
out_shape
;
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasOutput
(
"XShape"
))
{
return
framework
::
KernelSignature
(
"flatten_with_xshape"
,
{
"X"
},
{
"start_axis"
,
"stop_axis"
},
{
"Out"
,
"XShape"
});
}
else
{
return
framework
::
KernelSignature
(
"flatten"
,
{
"X"
},
{
"start_axis"
,
"stop_axis"
},
{
"Out"
});
}
}
};
class
FlattenContiguousRangeOpMaker
:
public
FlattenOpMaker
{
...
...
paddle/fluid/operators/reshape_op.cc
浏览文件 @
75923a32
...
...
@@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp {
ReshapeOp
::
InferShape
(
ctx
);
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
string
shape
;
auto
multi_inputs
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"ShapeTensor"
);
if
(
multi_inputs
.
size
()
>
0
)
{
shape
=
"ShapeTensor"
;
}
else
if
(
ctx
.
HasInput
(
"Shape"
))
{
shape
=
"Shape"
;
}
else
{
shape
=
"shape"
;
}
return
framework
::
KernelSignature
(
"reshape"
,
{
"X"
},
{
shape
},
{
"Out"
});
}
};
class
Reshape2OpMaker
:
public
ReshapeOpMaker
{
...
...
paddle/pten/core/compat/arg_map_context.h
浏览文件 @
75923a32
...
...
@@ -75,6 +75,9 @@ class ArgumentMappingContext {
virtual
bool
IsDenseTensorInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
IsSelectedRowsInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
IsDenseTensorOutput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
IsSelectedRowsOutput
(
const
std
::
string
&
name
)
const
=
0
;
};
}
// namespace pten
paddle/pten/ops/compat/cast_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
CastOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"cast"
,
{
"X"
},
{
"out_dtype"
},
{
"Out"
});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
cast
,
pten
::
CastOpArgumentMapping
);
paddle/pten/ops/compat/concat_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
ConcatOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"AxisTensor"
))
{
return
KernelSignature
(
"concat"
,
{
"X"
},
{
"AxisTensor"
},
{
"Out"
});
}
return
KernelSignature
(
"concat"
,
{
"X"
},
{
"axis"
},
{
"Out"
});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
concat
,
pten
::
ConcatOpArgumentMapping
);
paddle/pten/ops/compat/elementwise_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
ElementwiseAddOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
int
axis
=
paddle
::
any_cast
<
int
>
(
ctx
.
Attr
(
"axis"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
axis
==
-
1
)
{
return
KernelSignature
(
"add"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
KernelSignature
(
"add_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
KernelSignature
ElementwiseSubOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
int
axis
=
paddle
::
any_cast
<
int
>
(
ctx
.
Attr
(
"axis"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
axis
==
-
1
)
{
return
KernelSignature
(
"subtract"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
KernelSignature
(
"subtract_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
KernelSignature
ElementwiseMulOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
int
axis
=
paddle
::
any_cast
<
int
>
(
ctx
.
Attr
(
"axis"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
axis
==
-
1
)
{
return
KernelSignature
(
"multiply"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
KernelSignature
(
"multiply_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
KernelSignature
ElementwiseDivOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
int
axis
=
paddle
::
any_cast
<
int
>
(
ctx
.
Attr
(
"axis"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
axis
==
-
1
)
{
return
KernelSignature
(
"divide"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
}
return
KernelSignature
(
"divide_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
elementwise_add
,
pten
::
ElementwiseAddOpArgumentMapping
);
PT_REGISTER_ARG_MAPPING_FN
(
elementwise_sub
,
pten
::
ElementwiseSubOpArgumentMapping
);
PT_REGISTER_ARG_MAPPING_FN
(
elementwise_mul
,
pten
::
ElementwiseMulOpArgumentMapping
);
PT_REGISTER_ARG_MAPPING_FN
(
elementwise_div
,
pten
::
ElementwiseDivOpArgumentMapping
);
paddle/pten/ops/compat/empty_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
EmptyOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
return
KernelSignature
(
"empty"
,
{},
{
"ShapeTensor"
},
{
"Out"
});
}
else
if
(
ctx
.
InputSize
(
"ShapeTensorList"
)
>
0
)
{
return
KernelSignature
(
"empty"
,
{},
{
"ShapeTensorList"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"empty"
,
{},
{
"shape"
},
{
"Out"
});
}
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
empty
,
pten
::
EmptyOpArgumentMapping
);
paddle/pten/ops/compat/fill_any_like_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
FillAnyLikeOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"full_like"
,
{},
{
"value"
},
{
"Out"
});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
fill_any_like
,
pten
::
FillAnyLikeOpArgumentMapping
);
paddle/pten/ops/compat/fill_constant_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
// we have to return every specific KernelSignature for infrt now
KernelSignature
FillConstantOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
IsDenseTensorOutput
(
"Out"
))
{
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
if
(
ctx
.
HasInput
(
"ValueTensor"
))
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensor"
,
"ValueTensor"
},
{
"Out"
});
}
else
{
const
auto
&
str_value
=
paddle
::
any_cast
<
std
::
string
>
(
ctx
.
Attr
(
"str_value"
));
if
(
str_value
.
empty
())
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensor"
,
"value"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensor"
,
"str_value"
},
{
"Out"
});
}
}
}
else
if
(
ctx
.
InputSize
(
"ShapeTensorList"
)
>
0
)
{
if
(
ctx
.
HasInput
(
"ValueTensor"
))
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensorList"
,
"ValueTensor"
},
{
"Out"
});
}
else
{
const
auto
&
str_value
=
paddle
::
any_cast
<
std
::
string
>
(
ctx
.
Attr
(
"str_value"
));
if
(
str_value
.
empty
())
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensorList"
,
"value"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"full"
,
{},
{
"ShapeTensorList"
,
"str_value"
},
{
"Out"
});
}
}
}
else
{
if
(
ctx
.
HasInput
(
"ValueTensor"
))
{
return
KernelSignature
(
"full"
,
{},
{
"shape"
,
"ValueTensor"
},
{
"Out"
});
}
else
{
const
auto
&
str_value
=
paddle
::
any_cast
<
std
::
string
>
(
ctx
.
Attr
(
"str_value"
));
if
(
str_value
.
empty
())
{
return
KernelSignature
(
"full"
,
{},
{
"shape"
,
"value"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"full"
,
{},
{
"shape"
,
"str_value"
},
{
"Out"
});
}
}
}
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
fill_constant
,
pten
::
FillConstantOpArgumentMapping
);
paddle/pten/ops/compat/flatten_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
FlattenOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasOutput
(
"XShape"
))
{
return
KernelSignature
(
"flatten_with_xshape"
,
{
"X"
},
{
"start_axis"
,
"stop_axis"
},
{
"Out"
,
"XShape"
});
}
else
{
return
KernelSignature
(
"flatten"
,
{
"X"
},
{
"start_axis"
,
"stop_axis"
},
{
"Out"
});
}
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
flatten_contiguous_range
,
pten
::
FlattenOpArgumentMapping
);
paddle/pten/ops/compat/reduce_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
ReduceSumOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
bool
reduce_all
=
paddle
::
any_cast
<
bool
>
(
ctx
.
Attr
(
"reduce_all"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
!
reduce_all
)
{
return
KernelSignature
(
"sum"
,
{
"X"
},
{
"dim"
,
"keep_dim"
,
"out_dtype"
},
{
"Out"
});
}
return
KernelSignature
(
"sum_raw"
,
{
"X"
},
{
"dim"
,
"keep_dim"
,
"reduce_all"
,
"out_dtype"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
KernelSignature
ReduceMeanOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
bool
reduce_all
=
paddle
::
any_cast
<
bool
>
(
ctx
.
Attr
(
"reduce_all"
));
if
(
ctx
.
IsDenseTensorInput
(
"X"
))
{
if
(
!
reduce_all
)
{
return
KernelSignature
(
"mean"
,
{
"X"
},
{
"dim"
,
"keep_dim"
},
{
"Out"
});
}
return
KernelSignature
(
"mean_raw"
,
{
"X"
},
{
"dim"
,
"keep_dim"
,
"reduce_all"
},
{
"Out"
});
}
return
KernelSignature
(
"unregistered"
,
{},
{},
{});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
reduce_sum
,
pten
::
ReduceSumOpArgumentMapping
);
PT_REGISTER_ARG_MAPPING_FN
(
reduce_mean
,
pten
::
ReduceMeanOpArgumentMapping
);
paddle/pten/ops/compat/reshape_sig.cc
0 → 100644
浏览文件 @
75923a32
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
ReshapeOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
InputSize
(
"ShapeTensor"
)
>
0
)
{
return
KernelSignature
(
"reshape"
,
{
"X"
},
{
"ShapeTensor"
},
{
"Out"
});
}
else
if
(
ctx
.
HasInput
(
"Shape"
))
{
return
KernelSignature
(
"reshape"
,
{
"X"
},
{
"Shape"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"reshape"
,
{
"X"
},
{
"shape"
},
{
"Out"
});
}
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
reshape2
,
pten
::
ReshapeOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录