未验证 提交 038ca68d 编写于 作者: C Chen Weihang 提交者: GitHub

revert attr type change (#38129)

上级 e5a838f8
...@@ -186,6 +186,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -186,6 +186,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_ComputeCallHelper(const float&); PD_SPECIALIZE_ComputeCallHelper(const float&);
PD_SPECIALIZE_ComputeCallHelper(const int64_t&); PD_SPECIALIZE_ComputeCallHelper(const int64_t&);
// NOTE(chenweihang): Used to be compatible with the 2.1 released
// interface, but not recommended
PD_SPECIALIZE_ComputeCallHelper(std::string);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
// end: base template // end: base template
template <typename T> template <typename T>
struct ComputeCallHelper<TypeTag<T>> { struct ComputeCallHelper<TypeTag<T>> {
...@@ -329,6 +337,13 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -329,6 +337,13 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&); PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const float&);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&); PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(const int64_t&);
// NOTE(chenweihang): Used to be compatible with the 2.1 released
// interface, but not recommended
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(std::string);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(std::vector<int>);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(std::vector<float>);
PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR(std::vector<std::string>);
// end: base template // end: base template
template <typename T> template <typename T>
struct InferShapeCallHelper<TypeTag<T>> { struct InferShapeCallHelper<TypeTag<T>> {
......
...@@ -127,11 +127,11 @@ std::vector<paddle::Tensor> AttrTestForward( ...@@ -127,11 +127,11 @@ std::vector<paddle::Tensor> AttrTestForward(
int int_attr, int int_attr,
float float_attr, float float_attr,
int64_t int64_attr, int64_t int64_attr,
const std::string& str_attr, std::string str_attr,
const std::vector<int>& int_vec_attr, std::vector<int> int_vec_attr,
const std::vector<float>& float_vec_attr, std::vector<float> float_vec_attr,
const std::vector<int64_t>& int64_vec_attr, std::vector<int64_t> int64_vec_attr,
const std::vector<std::string>& str_vec_attr) { std::vector<std::string> str_vec_attr) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
PD_DISPATCH_FLOATING_TYPES( PD_DISPATCH_FLOATING_TYPES(
...@@ -160,10 +160,10 @@ std::vector<std::vector<int64_t>> AttrTestInferShape( ...@@ -160,10 +160,10 @@ std::vector<std::vector<int64_t>> AttrTestInferShape(
int int_attr, int int_attr,
float float_attr, float float_attr,
int64_t int64_attr, int64_t int64_attr,
const std::string& str_attr, std::string str_attr,
const std::vector<int>& int_vec_attr, std::vector<int> int_vec_attr,
const std::vector<float>& float_vec_attr, std::vector<float> float_vec_attr,
const std::vector<std::string>& str_vec_attr) { std::vector<std::string> str_vec_attr) {
return {x_shape}; return {x_shape};
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册