From e5c61b151e6e0af60dcc8cda0319b77544e88296 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 19 Apr 2022 23:13:46 +0800 Subject: [PATCH] polish tensor api details (#41971) --- paddle/phi/api/include/tensor.h | 6 +++--- paddle/phi/api/lib/tensor.cc | 4 ++-- paddle/phi/api/lib/tensor_method.cc | 4 ++-- python/paddle/utils/code_gen/api_base.py | 10 +++++----- python/paddle/utils/code_gen/strings_api_gen.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index e4a97e2c16..2b0aea9e1e 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -166,7 +166,7 @@ class PADDLE_API Tensor final { * * @return phi::DDim */ - phi::DDim dims() const; + const phi::DDim& dims() const; /** * @brief Return the shape (dimensions) of Tensor. @@ -260,7 +260,7 @@ class PADDLE_API Tensor final { * * @return Place */ - Place place() const; + const Place& place() const; /** * @brief Determine whether the tensor device is CPU @@ -421,7 +421,7 @@ class PADDLE_API Tensor final { * @param blocking, Should we copy this in sync way. * @return Tensor */ - Tensor copy_to(Place place, bool blocking) const; + Tensor copy_to(const Place& place, bool blocking) const; /** * @brief Transfer the source Tensor to current Tensor. diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index f1aa48a2a4..be0a937c91 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -110,7 +110,7 @@ int64_t Tensor::numel() const { return impl_->numel(); } int64_t Tensor::size() const { return impl_->numel(); } -phi::DDim Tensor::dims() const { return impl_->dims(); } +const phi::DDim &Tensor::dims() const { return impl_->dims(); } std::vector Tensor::shape() const { auto dims = impl_->dims(); @@ -158,7 +158,7 @@ bool Tensor::is_string_tensor() const { } /* Part 3: Device and Backend methods */ -Place Tensor::place() const { +const Place &Tensor::place() const { PADDLE_ENFORCE_NOT_NULL( impl_, phi::errors::PermissionDenied( diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 463b72d0db..5285392b4a 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -27,13 +27,13 @@ namespace paddle { namespace experimental { // declare cast api Tensor cast(const Tensor &x, DataType out_dtype); -Tensor copy_to(const Tensor &x, Place place, bool blocking); +Tensor copy_to(const Tensor &x, const Place &place, bool blocking); Tensor Tensor::cast(DataType target_type) const { return experimental::cast(*this, target_type); } -Tensor Tensor::copy_to(Place place, bool blocking) const { +Tensor Tensor::copy_to(const Place &place, bool blocking) const { return experimental::copy_to(*this, place, blocking); } diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index a6bd0a10cb..378ead7ff2 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -105,7 +105,7 @@ class BaseAPI(object): 'double': 'double', 'bool': 'bool', 'str': 'const std::string&', - 'Place': 'Place', + 'Place': 'const Place&', 'DataLayout': 'DataLayout', 'DataType': 'DataType', 'int64_t[]': 'const std::vector&', @@ -120,7 +120,7 @@ class BaseAPI(object): 'float': 'paddle::optional', 'double': 'paddle::optional', 'bool': 'paddle::optional', - 'Place': 'paddle::optional', + 'Place': 'paddle::optional', 'DataLayout': 'paddle::optional', 'DataType': 'paddle::optional' } @@ -328,7 +328,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self assert len( vars_list ) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." - assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \ + assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \ f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type." backend_select_code = f""" kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); @@ -360,7 +360,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self attr_layout_count = 0 attr_data_type_count = 0 for attr_name in attrs['names']: - if attrs['attr_info'][attr_name][0] == 'Place': + if attrs['attr_info'][attr_name][0] == 'const Place&': assert kernel['backend'] is not None, \ f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually." attr_backend_count = attr_backend_count + 1 @@ -420,7 +420,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self if len(input_names) == 0: assert attr_backend_count > 0 and attr_data_type_count > 0, \ - f"{api} api: When there is no input tensor, the args must have 'Backend' and 'DataType'." + f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'." kernel_select_args = "" for input_name in input_names: diff --git a/python/paddle/utils/code_gen/strings_api_gen.py b/python/paddle/utils/code_gen/strings_api_gen.py index d7117e9d54..061ea6c3ce 100644 --- a/python/paddle/utils/code_gen/strings_api_gen.py +++ b/python/paddle/utils/code_gen/strings_api_gen.py @@ -225,7 +225,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s assert len( vars_list ) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." - assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \ + assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \ f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type." kernel_select_code = kernel_select_code + f""" kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); -- GitLab