未验证 提交 715fd051 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add compatible reshape method for Tensor (#37281)

* add reshape method for Tensor

* fix typo

* fix typo

* fix conflit with develop
上级 6653ac5e
...@@ -163,9 +163,10 @@ class PD_DLL_DECL Tensor final { ...@@ -163,9 +163,10 @@ class PD_DLL_DECL Tensor final {
/** /**
* @brief Reset the shape of the tensor. * @brief Reset the shape of the tensor.
* Reshape must be called before calling mutable_data() or * Note: This method means Reset the shape of the tensor,
* copy_to(const PlaceType& place). * and must be called before calling mutable_data() or
* This is a deprecated method and may be removed in the future! * copy_to(const PlaceType& place), this is not a standard definition of
* reshape behavior, so we will deprecated this feature in the future.
* *
* @param shape * @param shape
*/ */
......
...@@ -105,9 +105,21 @@ std::vector<int64_t> Tensor::shape() const { ...@@ -105,9 +105,21 @@ std::vector<int64_t> Tensor::shape() const {
} }
void Tensor::reshape(const std::vector<int64_t> &shape) { void Tensor::reshape(const std::vector<int64_t> &shape) {
PADDLE_THROW(platform::errors::Unimplemented( LOG(WARNING) << "The function of resetting the shape of the uninitialized "
"The reshape operation is not supported now, " "Tensor of the `reshape` method is deprecated since version "
"and it will be implemented by calling the reshape kernel later.")); "2.3, and will be removed in version 2.4, please use "
"`paddle::experimental::full` method to create a new Tensor "
"instead. "
"reason: `reshape` means changing the tensor shape without "
"touching underlying data, this requires the total size of "
"the tensor to remain constant.";
if (detail::IsDenseTensor(impl_)) {
std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->set_meta(
pten::DenseTensorMeta(dtype(), framework::make_ddim(shape)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Only support reshape operation on DenseTensor now."));
}
} }
DataType Tensor::dtype() const { return impl_->dtype(); } DataType Tensor::dtype() const { return impl_->dtype(); }
...@@ -247,7 +259,7 @@ Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { ...@@ -247,7 +259,7 @@ Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
end_idx)))); end_idx))));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Only supported slice operation on DenseTensor now.")); "Only support slice operation on DenseTensor now."));
} }
} }
...@@ -314,12 +326,10 @@ Tensor Tensor::cast(const DataType &target_type) const { ...@@ -314,12 +326,10 @@ Tensor Tensor::cast(const DataType &target_type) const {
bool Tensor::defined() const { return impl_ != nullptr; } bool Tensor::defined() const { return impl_ != nullptr; }
bool Tensor::initialized() const { bool Tensor::initialized() const { return defined() && impl_->initialized(); }
return impl_ != nullptr && impl_->initialized();
}
bool Tensor::is_initialized() const { bool Tensor::is_initialized() const {
return impl_ != nullptr && impl_->initialized(); return defined() && impl_->initialized();
} }
void Tensor::reset() { impl_.reset(); } void Tensor::reset() { impl_.reset(); }
......
...@@ -27,6 +27,9 @@ PT_DECLARE_MODULE(ManipulationCPU); ...@@ -27,6 +27,9 @@ PT_DECLARE_MODULE(ManipulationCPU);
PT_DECLARE_MODULE(ManipulationCUDA); PT_DECLARE_MODULE(ManipulationCUDA);
#endif #endif
namespace pten {
namespace tests {
namespace framework = paddle::framework; namespace framework = paddle::framework;
using DDim = paddle::framework::DDim; using DDim = paddle::framework::DDim;
...@@ -68,3 +71,19 @@ TEST(API, reshape) { ...@@ -68,3 +71,19 @@ TEST(API, reshape) {
} }
ASSERT_EQ(value_equal, true); ASSERT_EQ(value_equal, true);
} }
TEST(Tensor, old_reshape) {
paddle::experimental::Tensor x(paddle::PlaceType::kCPU);
x.reshape({3, 4});
ASSERT_EQ(x.shape()[0], 3);
ASSERT_EQ(x.shape()[1], 4);
ASSERT_EQ(x.numel(), 12);
ASSERT_EQ(x.is_cpu(), true);
ASSERT_EQ(x.type(), pten::DataType::UNDEFINED);
ASSERT_EQ(x.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(x.initialized(), false);
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册