未验证 提交 b0d580a2 编写于 作者: S sneaxiy 提交者: GitHub

Fix inplace addto pass by setting dtype correctly (#37717)

* fix inplace addto pass

* update

* fix ut

* improve ci coverage

* fix musl ci compile error
上级 1a1aeff6
...@@ -39,14 +39,14 @@ ShareTensorBufferFunctor::ShareTensorBufferFunctor( ...@@ -39,14 +39,14 @@ ShareTensorBufferFunctor::ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names, const bool &is_variant_scope, const std::vector<std::string> &out_var_names, const bool &is_variant_scope,
bool share_dims) bool share_dims_and_dtype)
: scope_(scope), : scope_(scope),
scope_idx_(scope_idx), scope_idx_(scope_idx),
op_type_(op_type), op_type_(op_type),
in_var_infos_(in_var_infos), in_var_infos_(in_var_infos),
out_var_names_(out_var_names), out_var_names_(out_var_names),
is_variant_scope_(is_variant_scope), is_variant_scope_(is_variant_scope),
share_dims_(share_dims) { share_dims_and_dtype_(share_dims_and_dtype) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size(), PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The number of input variables and output variables " "The number of input variables and output variables "
...@@ -147,12 +147,14 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) { ...@@ -147,12 +147,14 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
// NOTE(zhiqiu): In the case of inplace addto, if the operator of // NOTE(zhiqiu): In the case of inplace addto, if the operator of
// the in_out_vars is skipped during running, we should set the dims of // the in_out_vars is skipped during running, we should set the dims of
// output as the same as input. // output as the same as input.
if (share_dims_) { if (share_dims_and_dtype_) {
out_tensor->Resize(in_tensor.dims()); out_tensor->Resize(in_tensor.dims());
out_tensor->ShareDataTypeWith(in_tensor);
} }
VLOG(2) << "Share tensor buffer when running " << op_type_ << " : " VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
<< in_var_info->Name() << " -> " << out_var_names_[i]; << in_var_info->Name() << " -> " << out_var_names_[i]
<< " share_dims_and_dtype = " << share_dims_and_dtype_;
} }
} }
} }
......
...@@ -73,12 +73,14 @@ class ShareTensorBufferFunctor { ...@@ -73,12 +73,14 @@ class ShareTensorBufferFunctor {
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names, const std::vector<std::string> &out_var_names,
const bool &is_variant_scope, bool share_dims = false); const bool &is_variant_scope, bool share_dims_and_dtype = false);
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info, void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name); const std::string &out_var_name);
void SetShareDims(bool share_dims) { share_dims_ = share_dims; } void SetShareDimsAndDtype(bool share_dims_and_dtype) {
share_dims_and_dtype_ = share_dims_and_dtype;
}
void operator()(Scope *exec_scope); void operator()(Scope *exec_scope);
...@@ -108,7 +110,7 @@ class ShareTensorBufferFunctor { ...@@ -108,7 +110,7 @@ class ShareTensorBufferFunctor {
// NOTE(zhiqiu): In the case of inplace addto, if the operator of // NOTE(zhiqiu): In the case of inplace addto, if the operator of
// the in_out_vars is skipped during running, we should set the dims of output // the in_out_vars is skipped during running, we should set the dims of output
// as the same as input. // as the same as input.
bool share_dims_{false}; bool share_dims_and_dtype_{false};
}; };
} // namespace details } // namespace details
......
...@@ -64,10 +64,10 @@ ComputationOpHandle *GetUniquePendingComputationOpHandle( ...@@ -64,10 +64,10 @@ ComputationOpHandle *GetUniquePendingComputationOpHandle(
ShareTensorBufferOpHandle::ShareTensorBufferOpHandle( ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type, ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names, bool share_dims) const std::vector<std::string> &out_var_names, bool share_dims_and_dtype)
: OpHandleBase(node), : OpHandleBase(node),
functor_(scope, scope_idx, op_type, in_var_infos, out_var_names, functor_(scope, scope_idx, op_type, in_var_infos, out_var_names,
is_variant_scope_, share_dims) {} is_variant_scope_, share_dims_and_dtype) {}
std::unordered_map<std::string, std::string> std::unordered_map<std::string, std::string>
ShareTensorBufferOpHandle::ReusedVars() const { ShareTensorBufferOpHandle::ReusedVars() const {
...@@ -79,8 +79,9 @@ void ShareTensorBufferOpHandle::AddReuseVarPair( ...@@ -79,8 +79,9 @@ void ShareTensorBufferOpHandle::AddReuseVarPair(
functor_.AddReuseVarPair(in_var_info, out_var_name); functor_.AddReuseVarPair(in_var_info, out_var_name);
} }
void ShareTensorBufferOpHandle::SetShareDims(bool share_dims) { void ShareTensorBufferOpHandle::SetShareDimsAndDtype(
functor_.SetShareDims(share_dims); bool share_dims_and_dtype) {
functor_.SetShareDimsAndDtype(share_dims_and_dtype);
} }
void ShareTensorBufferOpHandle::InitCUDA() { void ShareTensorBufferOpHandle::InitCUDA() {
......
...@@ -56,7 +56,7 @@ class ShareTensorBufferOpHandle : public OpHandleBase { ...@@ -56,7 +56,7 @@ class ShareTensorBufferOpHandle : public OpHandleBase {
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info, void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name); const std::string &out_var_name);
void SetShareDims(bool share_dims); void SetShareDimsAndDtype(bool share_dims_and_dtype);
const ShareTensorBufferFunctor &Functor() const { return functor_; } const ShareTensorBufferFunctor &Functor() const { return functor_; }
......
...@@ -283,7 +283,8 @@ void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program, ...@@ -283,7 +283,8 @@ void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program,
op->SetInput("X", inputs); op->SetInput("X", inputs);
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetOutput("XOut", inputs); // add necessary dependency op->SetOutput("XOut", inputs); // add necessary dependency
op->SetAttr("share_dims", std::vector<bool>(inputs.size(), false)); op->SetAttr("share_dims_and_dtype",
std::vector<bool>(inputs.size(), false));
} }
block->Flush(); block->Flush();
} }
......
...@@ -277,7 +277,7 @@ static void BuildInplaceAddToGraph(Node *in_var_0, Node *in_var_1, ...@@ -277,7 +277,7 @@ static void BuildInplaceAddToGraph(Node *in_var_0, Node *in_var_1,
grad_add_op_desc->SetInput("X", {in_var_1->Name()}); grad_add_op_desc->SetInput("X", {in_var_1->Name()});
grad_add_op_desc->SetOutput("Out", {out_var->Name()}); grad_add_op_desc->SetOutput("Out", {out_var->Name()});
grad_add_op_desc->SetOutput("XOut", {in_var_1->Name()}); grad_add_op_desc->SetOutput("XOut", {in_var_1->Name()});
grad_add_op_desc->SetAttr("share_dims", std::vector<bool>(1, true)); grad_add_op_desc->SetAttr("share_dims_and_dtype", std::vector<bool>(1, true));
// Add share_buffer op between in_var_0 and in_var_1 // Add share_buffer op between in_var_0 and in_var_1
OpDesc share_buffer_op; OpDesc share_buffer_op;
...@@ -285,7 +285,7 @@ static void BuildInplaceAddToGraph(Node *in_var_0, Node *in_var_1, ...@@ -285,7 +285,7 @@ static void BuildInplaceAddToGraph(Node *in_var_0, Node *in_var_1,
share_buffer_op.SetInput("X", {in_var_0->Name()}); share_buffer_op.SetInput("X", {in_var_0->Name()});
share_buffer_op.SetOutput("Out", {in_var_1->Name()}); share_buffer_op.SetOutput("Out", {in_var_1->Name()});
share_buffer_op.SetOutput("XOut", {in_var_0->Name()}); share_buffer_op.SetOutput("XOut", {in_var_0->Name()});
share_buffer_op.SetAttr("share_dims", std::vector<bool>(1, false)); share_buffer_op.SetAttr("share_dims_and_dtype", std::vector<bool>(1, false));
auto *new_share_buffer_op = graph->CreateOpNode(&share_buffer_op); auto *new_share_buffer_op = graph->CreateOpNode(&share_buffer_op);
new_share_buffer_op->inputs.push_back(in_var_0); new_share_buffer_op->inputs.push_back(in_var_0);
......
...@@ -329,7 +329,7 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -329,7 +329,7 @@ bool MemoryReusePass::IsVarPairReusable(
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var, details::VarHandle *in_var,
details::VarHandle *out_var, details::VarHandle *out_var,
bool share_dims) const { bool share_dims_and_dtype) const {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
(*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0, (*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0,
platform::errors::NotFound("Var(%s) does not in mem opt var infos.", platform::errors::NotFound("Var(%s) does not in mem opt var infos.",
...@@ -349,8 +349,8 @@ void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, ...@@ -349,8 +349,8 @@ void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
share_buffer_op->AddInput(in_var); share_buffer_op->AddInput(in_var);
} }
if (share_dims) { if (share_dims_and_dtype) {
share_buffer_op->SetShareDims(true); share_buffer_op->SetShareDimsAndDtype(true);
} }
share_buffer_op->AddReuseVarPair( share_buffer_op->AddReuseVarPair(
......
...@@ -260,6 +260,8 @@ class Tensor { ...@@ -260,6 +260,8 @@ class Tensor {
// should not be copied. // should not be copied.
} }
void ShareDataTypeWith(const Tensor& tensor) { type_ = tensor.type_; }
bool IsSharedBufferWith(const Tensor& src) const { bool IsSharedBufferWith(const Tensor& src) const {
return holder_ && holder_ == src.Holder(); return holder_ && holder_ == src.Holder();
} }
......
...@@ -203,6 +203,7 @@ elseif(WITH_ROCM) ...@@ -203,6 +203,7 @@ elseif(WITH_ROCM)
else() else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif() endif()
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context share_buffer_op)
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON) if (WITH_PYTHON)
......
...@@ -49,7 +49,8 @@ class ShareBufferOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,7 +49,8 @@ class ShareBufferOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), The output tensors which are the same as X. It is " "(Tensor), The output tensors which are the same as X. It is "
"used to build the graph dependency") "used to build the graph dependency")
.AsDuplicable(); .AsDuplicable();
AddAttr<std::vector<bool>>("share_dims", "Whether to share dims") AddAttr<std::vector<bool>>("share_dims_and_dtype",
"Whether to share dims and data type")
.SetDefault(std::vector<bool>()); .SetDefault(std::vector<bool>());
AddComment( AddComment(
R"DOC(Operator used to perform inplace memory reuse. It should be not exposed to Python APIs.)DOC"); R"DOC(Operator used to perform inplace memory reuse. It should be not exposed to Python APIs.)DOC");
......
...@@ -29,12 +29,13 @@ class ShareBufferOpKernel : public framework::OpKernel<T> { ...@@ -29,12 +29,13 @@ class ShareBufferOpKernel : public framework::OpKernel<T> {
size_t n = inputs.size(); size_t n = inputs.size();
PADDLE_ENFORCE_EQ(n, outputs.size(), platform::errors::PermissionDenied( PADDLE_ENFORCE_EQ(n, outputs.size(), platform::errors::PermissionDenied(
"Variable number not match.")); "Variable number not match."));
const auto &share_dims = ctx.Attr<std::vector<bool>>("share_dims"); const auto &share_dims_and_dtype =
if (!share_dims.empty()) { ctx.Attr<std::vector<bool>>("share_dims_and_dtype");
PADDLE_ENFORCE_EQ( if (!share_dims_and_dtype.empty()) {
n, share_dims.size(), PADDLE_ENFORCE_EQ(n, share_dims_and_dtype.size(),
platform::errors::PermissionDenied( platform::errors::PermissionDenied(
"Attribute share_dims number not match input variable number.")); "Attribute share_dims_and_dtype number not match "
"input variable number."));
} }
const std::vector<std::string> *input_args = nullptr, const std::vector<std::string> *input_args = nullptr,
...@@ -50,8 +51,9 @@ class ShareBufferOpKernel : public framework::OpKernel<T> { ...@@ -50,8 +51,9 @@ class ShareBufferOpKernel : public framework::OpKernel<T> {
outputs[i]->ShareBufferWith(*inputs[i]); outputs[i]->ShareBufferWith(*inputs[i]);
VLOG(10) << "Share tensor buffer " << (*input_args)[i] << " -> " VLOG(10) << "Share tensor buffer " << (*input_args)[i] << " -> "
<< (*output_args)[i]; << (*output_args)[i];
if (!share_dims.empty() && share_dims[i]) { if (!share_dims_and_dtype.empty() && share_dims_and_dtype[i]) {
outputs[i]->Resize(inputs[i]->dims()); outputs[i]->Resize(inputs[i]->dims());
outputs[i]->ShareDataTypeWith(*inputs[i]);
} }
} }
} }
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
USE_OP(share_buffer);
namespace paddle {
namespace framework {
// Runs the share_buffer op on two input/output pairs and verifies its
// contract: every output must share the input's buffer, and the pair whose
// "share_dims_and_dtype" flag is true must additionally inherit the input's
// dims and data type (the behavior this op's kernel implements when the
// flag is set).
TEST(test_share_buffer_op, test_share_buffer_op) {
  std::vector<std::string> inputs = {"X1", "X2"};
  std::vector<std::string> outputs = {"Y1", "Y2"};
  std::vector<DDim> dims = {{2, 3, 4}, {5, 6}};
  // First pair: buffer sharing only. Second pair: also share dims/dtype.
  std::vector<bool> share_dims_and_dtype = {false, true};
  size_t n = inputs.size();
  EXPECT_EQ(n, outputs.size());
  EXPECT_EQ(n, dims.size());
  EXPECT_EQ(n, share_dims_and_dtype.size());

  OpDesc desc;
  desc.SetType("share_buffer");
  desc.SetInput("X", inputs);
  desc.SetOutput("Out", outputs);
  desc.SetOutput("XOut", inputs);  // XOut only adds graph dependency edges
  desc.SetAttr("share_dims_and_dtype", share_dims_and_dtype);
  auto op = OpRegistry::CreateOp(desc);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  platform::Place place = platform::CUDAPlace(0);
#else
  platform::Place place = platform::CPUPlace();
#endif

  Scope scope;
  for (size_t i = 0; i < n; ++i) {
    auto *in_tensor = scope.Var(inputs[i])->GetMutable<LoDTensor>();
    in_tensor->Resize(dims[i]);
    in_tensor->mutable_data<float>(place);
    // Outputs are created empty; the op is responsible for filling them in.
    scope.Var(outputs[i])->GetMutable<LoDTensor>();
  }
  op->Run(scope, place);
  // Wait so device-side effects are visible before checking (no-op on CPU).
  platform::DeviceContextPool::Instance().Get(place)->Wait();

  for (size_t i = 0; i < n; ++i) {
    const auto &in_tensor = scope.Var(inputs[i])->Get<LoDTensor>();
    const auto &out_tensor = scope.Var(outputs[i])->Get<LoDTensor>();
    EXPECT_TRUE(out_tensor.IsSharedBufferWith(in_tensor));
    // The kernel resizes the output and shares the dtype when the flag is
    // set; assert that here so the dims/dtype-propagation path (the point
    // of this fix) is actually covered by CI, not just buffer sharing.
    if (share_dims_and_dtype[i]) {
      EXPECT_EQ(out_tensor.dims(), in_tensor.dims());
      EXPECT_EQ(out_tensor.type(), in_tensor.type());
    }
  }
}
} // namespace framework
} // namespace paddle
...@@ -123,7 +123,7 @@ class TestIRPassBase(unittest.TestCase): ...@@ -123,7 +123,7 @@ class TestIRPassBase(unittest.TestCase):
if op.type != "share_buffer": if op.type != "share_buffer":
continue continue
share_dims = op.attr("share_dims") share_dims = op.attr("share_dims_and_dtype")
if share_dims: if share_dims:
for i in range(len(share_dims)): for i in range(len(share_dims)):
self.assertEqual(share_dims[0], share_dims[i]) self.assertEqual(share_dims[0], share_dims[i])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册