未验证 提交 715fd051 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add compatible reshape method for Tensor (#37281)

* add reshape method for Tensor

* fix typo

* fix typo

* fix conflit with develop
上级 6653ac5e
......@@ -163,9 +163,10 @@ class PD_DLL_DECL Tensor final {
/**
* @brief Reset the shape of the tensor.
* Reshape must be called before calling mutable_data() or
* copy_to(const PlaceType& place).
* This is a deprecated method and may be removed in the future!
* Note: This method means Reset the shape of the tensor,
* and must be called before calling mutable_data() or
* copy_to(const PlaceType& place), this is not a standard definition of
* reshape behavior, so we will deprecated this feature in the future.
*
* @param shape
*/
......
......@@ -105,9 +105,21 @@ std::vector<int64_t> Tensor::shape() const {
}
void Tensor::reshape(const std::vector<int64_t> &shape) {
LOG(WARNING) << "The function of resetting the shape of the uninitialized "
"Tensor of the `reshape` method is deprecated since version "
"2.3, and will be removed in version 2.4, please use "
"`paddle::experimental::full` method to create a new Tensor "
"instead. "
"reason: `reshape` means changing the tensor shape without "
"touching underlying data, this requires the total size of "
"the tensor to remain constant.";
if (detail::IsDenseTensor(impl_)) {
std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->set_meta(
pten::DenseTensorMeta(dtype(), framework::make_ddim(shape)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The reshape operation is not supported now, "
"and it will be implemented by calling the reshape kernel later."));
"Only support reshape operation on DenseTensor now."));
}
}
DataType Tensor::dtype() const { return impl_->dtype(); }
......@@ -247,7 +259,7 @@ Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
end_idx))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Only supported slice operation on DenseTensor now."));
"Only support slice operation on DenseTensor now."));
}
}
......@@ -314,12 +326,10 @@ Tensor Tensor::cast(const DataType &target_type) const {
bool Tensor::defined() const { return impl_ != nullptr; }
bool Tensor::initialized() const {
return impl_ != nullptr && impl_->initialized();
}
bool Tensor::initialized() const { return defined() && impl_->initialized(); }
bool Tensor::is_initialized() const {
return impl_ != nullptr && impl_->initialized();
return defined() && impl_->initialized();
}
void Tensor::reset() { impl_.reset(); }
......
......@@ -27,6 +27,9 @@ PT_DECLARE_MODULE(ManipulationCPU);
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
......@@ -68,3 +71,19 @@ TEST(API, reshape) {
}
ASSERT_EQ(value_equal, true);
}
TEST(Tensor, old_reshape) {
paddle::experimental::Tensor x(paddle::PlaceType::kCPU);
x.reshape({3, 4});
ASSERT_EQ(x.shape()[0], 3);
ASSERT_EQ(x.shape()[1], 4);
ASSERT_EQ(x.numel(), 12);
ASSERT_EQ(x.is_cpu(), true);
ASSERT_EQ(x.type(), pten::DataType::UNDEFINED);
ASSERT_EQ(x.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(x.initialized(), false);
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册