未验证 提交 d4b007af 编写于 作者: Y YuanRisheng 提交者: GitHub

add share dims (#40238)

上级 c39aa18e
......@@ -232,16 +232,8 @@ class CompatMetaTensor : public phi::MetaTensor {
}
}
void share_meta(const MetaTensor& meta_tensor) override {
void share_dims(const MetaTensor& meta_tensor) override {
set_dims(meta_tensor.dims());
set_dtype(meta_tensor.dtype());
// VarDesc doesn't contains layout, so we cannot share layout
// set_layout(meta_tensor.layout());
// special case 1: share lod of LoDTensor
share_lod(meta_tensor);
// special case 2: share height and rows of SelectedRows in runtime
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::SelectedRows>()) {
......@@ -254,6 +246,16 @@ class CompatMetaTensor : public phi::MetaTensor {
}
}
void share_meta(const MetaTensor& meta_tensor) override {
set_dtype(meta_tensor.dtype());
// VarDesc doesn't contains layout, so we cannot share layout
// set_layout(meta_tensor.layout());
// special case 1: share lod of LoDTensor
share_lod(meta_tensor);
share_dims(meta_tensor);
}
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
......
......@@ -98,13 +98,9 @@ const LoD& MetaTensor::lod() const {
}
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
if (phi::DenseTensor::classof(tensor_)) {
set_dims(meta_tensor.dims());
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
share_lod(meta_tensor);
} else if (phi::SelectedRows::classof(tensor_)) {
set_dims(meta_tensor.dims());
if (phi::DenseTensor::classof(tensor_) ||
phi::SelectedRows::classof(tensor_)) {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
share_lod(meta_tensor);
......@@ -114,4 +110,29 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
}
}
TensorBase* MetaTensor::get_tensor() const { return tensor_; }
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
bool is_selected_rows = phi::SelectedRows::classof(tensor_);
if (is_dense_tensor || is_selected_rows) {
set_dims(meta_tensor.dims());
if (is_selected_rows) {
const auto in_tensor_base = meta_tensor.get_tensor();
PADDLE_ENFORCE_EQ(
phi::SelectedRows::classof(in_tensor_base),
true,
errors::InvalidArgument("The input MetaTensor is SelectedRows, but "
"the output MetaTensor is not this type."));
auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
selected_rows_out->set_rows(selected_rows_in->rows());
selected_rows_out->set_height(selected_rows_in->height());
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
}
}
} // namespace phi
......@@ -60,12 +60,13 @@ class MetaTensor {
virtual void share_lod(const MetaTensor& meta_tensor);
virtual void share_meta(const MetaTensor& meta_tensor);
virtual void share_dims(const MetaTensor& meta_tensor);
private:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
const LoD& lod() const;
TensorBase* get_tensor() const;
TensorBase* tensor_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册