未验证 提交 e5c61b15 编写于 作者: C Chen Weihang 提交者: GitHub

polish tensor api details (#41971)

上级 8113c913
......@@ -166,7 +166,7 @@ class PADDLE_API Tensor final {
*
* @return phi::DDim
*/
phi::DDim dims() const;
const phi::DDim& dims() const;
/**
* @brief Return the shape (dimensions) of Tensor.
......@@ -260,7 +260,7 @@ class PADDLE_API Tensor final {
*
* @return Place
*/
Place place() const;
const Place& place() const;
/**
* @brief Determine whether the tensor device is CPU
......@@ -421,7 +421,7 @@ class PADDLE_API Tensor final {
* @param blocking, Should we copy this in sync way.
* @return Tensor
*/
Tensor copy_to(Place place, bool blocking) const;
Tensor copy_to(const Place& place, bool blocking) const;
/**
* @brief Transfer the source Tensor to current Tensor.
......
......@@ -110,7 +110,7 @@ int64_t Tensor::numel() const { return impl_->numel(); }
int64_t Tensor::size() const { return impl_->numel(); }
phi::DDim Tensor::dims() const { return impl_->dims(); }
const phi::DDim &Tensor::dims() const { return impl_->dims(); }
std::vector<int64_t> Tensor::shape() const {
auto dims = impl_->dims();
......@@ -158,7 +158,7 @@ bool Tensor::is_string_tensor() const {
}
/* Part 3: Device and Backend methods */
Place Tensor::place() const {
const Place &Tensor::place() const {
PADDLE_ENFORCE_NOT_NULL(
impl_,
phi::errors::PermissionDenied(
......
......@@ -27,13 +27,13 @@ namespace paddle {
namespace experimental {
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
Tensor copy_to(const Tensor &x, Place place, bool blocking);
Tensor copy_to(const Tensor &x, const Place &place, bool blocking);
Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}
Tensor Tensor::copy_to(Place place, bool blocking) const {
Tensor Tensor::copy_to(const Place &place, bool blocking) const {
return experimental::copy_to(*this, place, blocking);
}
......
......@@ -105,7 +105,7 @@ class BaseAPI(object):
'double': 'double',
'bool': 'bool',
'str': 'const std::string&',
'Place': 'Place',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
......@@ -120,7 +120,7 @@ class BaseAPI(object):
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<Place>',
'Place': 'paddle::optional<const Place&>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>'
}
......@@ -328,7 +328,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
assert len(
vars_list
) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \
f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
backend_select_code = f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
......@@ -360,7 +360,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
attr_layout_count = 0
attr_data_type_count = 0
for attr_name in attrs['names']:
if attrs['attr_info'][attr_name][0] == 'Place':
if attrs['attr_info'][attr_name][0] == 'const Place&':
assert kernel['backend'] is not None, \
f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually."
attr_backend_count = attr_backend_count + 1
......@@ -420,7 +420,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
if len(input_names) == 0:
assert attr_backend_count > 0 and attr_data_type_count > 0, \
f"{api} api: When there is no input tensor, the args must have 'Backend' and 'DataType'."
f"{api} api: When there is no input tensor, the args must have 'Place' and 'DataType'."
kernel_select_args = ""
for input_name in input_names:
......
......@@ -225,7 +225,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name()}({self.args_s
assert len(
vars_list
) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'const Place&'), \
f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册