未验证 提交 878a40f5 编写于 作者: Z Zeng Jinle 提交者: GitHub

Support NoNeedBufferVarsInference in dygraph backward (#20868)

* support no need buffer vars in dygraph, test=develop

* fix inference compilation error, test=develop

* update no_need_buffer_vars_inference, test=develop

* add unittests for no_need_buffer_vars_context, test=develop

* refine no_need_buffer_vars by return ref, test=develop

* polish some codes, test=develop
上级 bf379fef
...@@ -191,13 +191,6 @@ copy(fluid_lib_dist ...@@ -191,13 +191,6 @@ copy(fluid_lib_dist
${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h ${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet) DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet)
set(module "imperative")
copy(fluid_lib_dist
SRCS ${src_dir}/${module}/type_defs.h ${src_dir}/${module}/dygraph_grad_maker.h ${src_dir}/${module}/layer.h ${src_dir}/${module}/flags.h
DSTS ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/
)
set(module "operators") set(module "operators")
copy(fluid_lib_dist copy(fluid_lib_dist
SRCS ${src_dir}/${module}/reader/blocking_queue.h SRCS ${src_dir}/${module}/reader/blocking_queue.h
...@@ -224,6 +217,12 @@ copy(fluid_lib_dist ...@@ -224,6 +217,12 @@ copy(fluid_lib_dist
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat
) )
set(module "imperative")
copy(fluid_lib_dist
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/jit/*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/jit
)
set(module "pybind") set(module "pybind")
copy(fluid_lib_dist copy(fluid_lib_dist
SRCS ${CMAKE_CURRENT_BINARY_DIR}/paddle/fluid/${module}/pybind.h SRCS ${CMAKE_CURRENT_BINARY_DIR}/paddle/fluid/${module}/pybind.h
......
...@@ -118,9 +118,12 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc ...@@ -118,9 +118,12 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context) device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc DEPS attribute)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto no_need_buffer_vars_inference)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_test(no_need_buffer_vars_inference_test SRCS no_need_buffer_vars_inference_test.cc DEPS no_need_buffer_vars_inference layer)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
...@@ -231,9 +232,9 @@ struct OpInfoFiller<T, kShapeInference> { ...@@ -231,9 +232,9 @@ struct OpInfoFiller<T, kShapeInference> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> { struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->infer_inplace_ = [](const OpDesc& op_desc, bool use_cuda) { info->infer_inplace_ = [](bool use_cuda) {
T infer; T infer;
return infer(op_desc, use_cuda); return infer(use_cuda);
}; };
} }
}; };
...@@ -241,12 +242,7 @@ struct OpInfoFiller<T, kInplaceOpInference> { ...@@ -241,12 +242,7 @@ struct OpInfoFiller<T, kInplaceOpInference> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> { struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs, info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
const VariableNameMap& outputs,
const AttributeMap& attrs) {
T infer(inputs, outputs, attrs);
return infer();
};
} }
}; };
......
...@@ -32,34 +32,30 @@ class InplaceOpInference { ...@@ -32,34 +32,30 @@ class InplaceOpInference {
public: public:
virtual ~InplaceOpInference() {} virtual ~InplaceOpInference() {}
virtual std::unordered_map<std::string, std::string> operator()( virtual std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const = 0; bool use_cuda) const = 0;
};
/*
Inplace In and Out for operator only have an Input and an Output.
For example, activation op.
*/
class SingleOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
auto inputs = op_desc.InputNames();
auto outputs = op_desc.OutputNames();
PADDLE_ENFORCE_EQ(inputs.size(), 1, "Op inputs must be unique");
PADDLE_ENFORCE_EQ(outputs.size(), 1, "Op outputs must be unique");
return {{inputs[0], outputs[0]}};
}
}; };
#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \ #define DECLARE_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \ class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \ public: \
std::unordered_map<std::string, std::string> operator()( \ std::unordered_map<std::string, std::string> operator()( \
const ::paddle::framework::OpDesc& op_desc, \
bool use_cuda) const final { \ bool use_cuda) const final { \
return {__VA_ARGS__}; \ return {__VA_ARGS__}; \
} \ } \
} }
#define DECLARE_CUDA_ONLY_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \
std::unordered_map<std::string, std::string> operator()( \
bool use_cuda) const final { \
if (use_cuda) { \
return {__VA_ARGS__}; \
} else { \
return {}; \
} \
} \
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -81,7 +81,7 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const { ...@@ -81,7 +81,7 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
auto *op_desc = op->Node()->Op(); auto *op_desc = op->Node()->Op();
auto in_to_outs = auto in_to_outs =
OpInfoMap::Instance().Get(op_type).infer_inplace_(*op_desc, use_cuda); OpInfoMap::Instance().Get(op_type).infer_inplace_(use_cuda);
for (auto &pair : in_to_outs) { for (auto &pair : in_to_outs) {
auto &in_param = pair.first; auto &in_param = pair.first;
auto &in_args = op_desc->Input(in_param); auto &in_args = op_desc->Input(in_param);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include <string>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
const Attribute &InferNoNeedBufferVarsContext::GetAttr(
const std::string &name) const {
auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s",
name);
return iter->second;
}
StaticGraphInferNoNeedBufferVarsContext::
StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
const std::string &slot) const {
auto iter = outputs_.find(slot);
if (iter != outputs_.end()) {
for (auto &var : iter->second) {
if (var != kEmptyVarName) return true;
}
}
return false;
}
DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool DyGraphInferNoNeedBufferVarsContext::HasOutput(
const std::string &slot) const {
auto iter = outputs_.find(slot);
if (iter != outputs_.end()) {
for (auto &var : iter->second) {
if (var) return true;
}
}
return false;
}
} // namespace framework
} // namespace paddle
...@@ -14,48 +14,119 @@ ...@@ -14,48 +14,119 @@
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class NoNeedBufferVarsInference { class InferNoNeedBufferVarsContext {
public: public:
NoNeedBufferVarsInference(const VariableNameMap &inputs, explicit InferNoNeedBufferVarsContext(const framework::AttributeMap &attrs)
const VariableNameMap &outputs, : attrs_(attrs) {}
const AttributeMap &attrs) virtual ~InferNoNeedBufferVarsContext() = default;
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
virtual ~NoNeedBufferVarsInference() = default; virtual bool HasOutput(const std::string &slot) const = 0;
const VariableNameMap &Inputs() const { return inputs_; } const Attribute &GetAttr(const std::string &attr) const;
const VariableNameMap &Outputs() const { return outputs_; } private:
const framework::AttributeMap &attrs_;
};
const AttributeMap &Attrs() const { return attrs_; } class StaticGraphInferNoNeedBufferVarsContext final
: public InferNoNeedBufferVarsContext {
public:
StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs);
virtual std::unordered_set<std::string> operator()() const = 0; bool HasOutput(const std::string &slot) const final;
private: private:
const VariableNameMap &inputs_; const VariableNameMap &inputs_;
const VariableNameMap &outputs_; const VariableNameMap &outputs_;
const AttributeMap &attrs_;
}; };
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \ class DyGraphInferNoNeedBufferVarsContext final
class class_type final \ : public InferNoNeedBufferVarsContext {
: public ::paddle::framework::NoNeedBufferVarsInference { \ public:
public: \ DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
using ::paddle::framework::NoNeedBufferVarsInference:: \ const imperative::NameVarBaseMap &outputs,
NoNeedBufferVarsInference; \ const AttributeMap &attr);
\
std::unordered_set<std::string> operator()() const final { \ bool HasOutput(const std::string &slot) const final;
return {__VA_ARGS__}; \
} \ private:
const imperative::NameVarBaseMap &inputs_;
const imperative::NameVarBaseMap &outputs_;
};
class NoNeedBufferVarsInference {
public:
virtual ~NoNeedBufferVarsInference() = default;
virtual const std::unordered_set<std::string> &operator()(
const InferNoNeedBufferVarsContext &ctx) const = 0;
protected:
static const std::unordered_set<std::string> &Empty() {
static std::unordered_set<std::string> empty;
return empty;
}
};
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
class class_type final \
: public ::paddle::framework::NoNeedBufferVarsInference { \
public: \
using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \
\
const std::unordered_set<std::string> &operator()( \
const ::paddle::framework::InferNoNeedBufferVarsContext &ctx) \
const final { \
static std::unordered_set<std::string> __ret__{__VA_ARGS__}; \
return __ret__; \
} \
}
class InferNoNeedBufferVarsFN {
public:
inline const std::unordered_set<std::string> &operator()(
const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx);
}
inline const std::unordered_set<std::string> &operator()(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx);
} }
inline operator bool() const { return inferer_ != nullptr; }
inline bool operator!() const { return inferer_ == nullptr; }
inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) {
PADDLE_ENFORCE_NOT_NULL(inferer);
PADDLE_ENFORCE_EQ(inferer_, nullptr);
inferer_ = inferer;
}
private:
std::shared_ptr<NoNeedBufferVarsInference> inferer_;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle {
namespace framework {
TEST(test_no_need_buffer_vars_inference, test_static_graph) {
AttributeMap attrs{{"is_test", true}};
VariableNameMap inputs;
VariableNameMap outputs{{"Out", {kEmptyVarName, "tmp_0"}}};
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_FALSE(ctx.HasOutput("X"));
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
}
TEST(test_no_need_buffer_vars_inference, test_dygraph) {
AttributeMap attrs{{"is_test", true}};
imperative::NameVarBaseMap inputs;
imperative::NameVarBaseMap outputs;
outputs["Out"].emplace_back(nullptr);
outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_FALSE(ctx.HasOutput("X"));
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
}
} // namespace framework
} // namespace paddle
...@@ -233,15 +233,17 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -233,15 +233,17 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss; std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{"; ss << "Op(" << type_ << "), inputs:{";
std::unordered_set<std::string> no_need_buffer_vars; const std::unordered_set<std::string>* no_need_buffer_vars = nullptr;
if (info_ && info_->NoNeedBufferVarsInferer()) { if (info_ && info_->NoNeedBufferVarsInferer()) {
no_need_buffer_vars = no_need_buffer_vars =
Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs()); &(Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs()));
if (no_need_buffer_vars->empty()) no_need_buffer_vars = nullptr;
} }
for (auto it = inputs_.begin(); it != inputs_.end();) { for (auto it = inputs_.begin(); it != inputs_.end();) {
auto& input = *it; auto& input = *it;
bool is_no_need_buffer_var = (no_need_buffer_vars.count(input.first) > 0); bool is_no_need_buffer_var =
(no_need_buffer_vars && no_need_buffer_vars->count(input.first) > 0);
ss << input.first << "["; ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) { for (size_t i = 0; i < input.second.size(); ++i) {
auto var_name = input.second[i]; auto var_name = input.second[i];
...@@ -1027,21 +1029,18 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1027,21 +1029,18 @@ Scope* OperatorWithKernel::PrepareData(
RuntimeContext* ctx) const { RuntimeContext* ctx) const {
Scope* new_scope = nullptr; Scope* new_scope = nullptr;
std::unordered_set<std::string> no_buffer_ins; const std::unordered_set<std::string>* no_buffer_ins = nullptr;
if (info_) { if (info_) {
auto& no_buffer_inferer = info_->NoNeedBufferVarsInferer(); auto& no_buffer_inferer = info_->NoNeedBufferVarsInferer();
// Some op may not register NoNeedBufferVarsInferer // Some op may not register NoNeedBufferVarsInferer
if (no_buffer_inferer) { if (no_buffer_inferer) {
no_buffer_ins = no_buffer_inferer(Inputs(), Outputs(), Attrs()); no_buffer_ins = &(no_buffer_inferer(Inputs(), Outputs(), Attrs()));
if (no_buffer_ins->empty()) no_buffer_ins = nullptr;
} }
} }
for (auto& var_name_item : Inputs()) { for (auto& var_name_item : Inputs()) {
// NOTE(zjl): STL does not guarantee fast std::unordered_set::count when set if (no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0) {
// is empty. At least STL implemented on my mac does calculate hash code
// of search key even though the set is empty.
if (!no_buffer_ins.empty() &&
no_buffer_ins.count(var_name_item.first) > 0) {
VLOG(7) << "Skip scanning input " << var_name_item.first VLOG(7) << "Skip scanning input " << var_name_item.first
<< " in Operator " << type_; << " in Operator " << type_;
continue; continue;
......
...@@ -31,7 +31,7 @@ class InferShapeContext; ...@@ -31,7 +31,7 @@ class InferShapeContext;
class InferVarTypeContext; class InferVarTypeContext;
class BlockDesc; class BlockDesc;
class Variable; class Variable;
class NoNeedBufferVarsInference; class InferNoNeedBufferVarsFN;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector. // TODO(panyx0718): Replace vector with something like gtl::Vector.
...@@ -67,11 +67,7 @@ using InferVarTypeFN = ...@@ -67,11 +67,7 @@ using InferVarTypeFN =
using InferShapeFN = std::function<void(InferShapeContext*)>; using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>; using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&, bool)>; using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;
using InferNoNeedBufferVarsFN = std::function<std::unordered_set<std::string>(
const VariableNameMap& /*inputs*/, const VariableNameMap& /*outputs*/,
const AttributeMap& /*attrs*/)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,36 @@ ...@@ -20,6 +20,36 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
static void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return;
auto* ins = op->GetMutableInsMap();
const auto& no_need_buffer_slots =
inferer(*ins, op->GetOutsMap(), op->Attrs());
if (no_need_buffer_slots.empty()) return;
for (auto& slot : no_need_buffer_slots) {
auto iter = ins->find(slot);
if (iter == ins->end()) continue;
VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type();
for (auto& each_var : iter->second) {
if (!each_var) continue;
auto& var = each_var->Var();
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor");
// TODO(zjl): support higher order derivatives
auto new_var = new VarBase(false, each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
each_var.reset(new_var);
}
}
}
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases( static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
const OpBase* fw_op_base, const NameVarBaseMap& in, const OpBase* fw_op_base, const NameVarBaseMap& in,
const NameVarBaseMap& out) { const NameVarBaseMap& out) {
...@@ -151,6 +181,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op, ...@@ -151,6 +181,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
// this OpBase* is just used to manage op's life time // this OpBase* is just used to manage op's life time
engine_->InsertOp(grad_op.get(), grad_op); engine_->InsertOp(grad_op.get(), grad_op);
ClearNoNeedBufferInputs(grad_op.get());
} }
} }
......
...@@ -93,6 +93,7 @@ if (WITH_GPU) ...@@ -93,6 +93,7 @@ if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif() endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
# FIXME(typhoonzero): operator deps may not needed. # FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
......
...@@ -888,6 +888,7 @@ class PowOpGrad : public framework::OperatorWithKernel { ...@@ -888,6 +888,7 @@ class PowOpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
}; };
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -903,8 +904,7 @@ namespace plat = paddle::platform; ...@@ -903,8 +904,7 @@ namespace plat = paddle::platform;
ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(), \ ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(), \
paddle::imperative::OpBase>, \ paddle::imperative::OpBase>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \ std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \ ops::ActFwdInplaceInferer, void>::type); \
void>::type); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInference);
...@@ -932,7 +932,7 @@ REGISTER_OPERATOR( ...@@ -932,7 +932,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>, paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(), ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker<paddle::framework::OpDesc>, ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
...@@ -962,7 +962,7 @@ REGISTER_OPERATOR( ...@@ -962,7 +962,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>, paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(), ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>, ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
...@@ -991,7 +991,7 @@ REGISTER_OPERATOR( ...@@ -991,7 +991,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>, paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(), ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>, ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
...@@ -1019,7 +1019,7 @@ REGISTER_OPERATOR( ...@@ -1019,7 +1019,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>, paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(), ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>, paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker<paddle::framework::OpDesc>, ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
...@@ -1049,7 +1049,7 @@ REGISTER_OPERATOR( ...@@ -1049,7 +1049,7 @@ REGISTER_OPERATOR(
ops::PowGradOpMaker<paddle::framework::OpDesc>, ops::PowGradOpMaker<paddle::framework::OpDesc>,
ops::PowGradOpMaker<paddle::imperative::OpBase>, ops::PowGradOpMaker<paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(), std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
::paddle::framework::SingleOpInplaceInToOut, void>::type); ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad, REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInference); ops::ActivationGradOpInplaceInference);
......
...@@ -298,24 +298,14 @@ class AffineChannelNoNeedBufferVarsInference ...@@ -298,24 +298,14 @@ class AffineChannelNoNeedBufferVarsInference
public: public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference; using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private: const std::unordered_set<std::string>& operator()(
inline bool HasOutput(const std::string& name) const { const framework::InferNoNeedBufferVarsContext& ctx) const final {
auto& outputs = Outputs(); static const std::unordered_set<std::string> kX({"X"});
auto iter = outputs.find(name); if (!ctx.HasOutput(framework::GradVarName("Scale")) &&
if (iter == outputs.end() || iter->second.empty()) { !ctx.HasOutput(framework::GradVarName("Bias"))) {
return false; return kX;
} else { } else {
return iter->second[0] != framework::kEmptyVarName; return Empty();
}
}
public:
std::unordered_set<std::string> operator()() const override {
if (!HasOutput(framework::GradVarName("Scale")) &&
!HasOutput(framework::GradVarName("Bias"))) {
return {"X"};
} else {
return {};
} }
} }
}; };
......
...@@ -98,7 +98,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -98,7 +98,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
using ScaleOpInplace = framework::SingleOpInplaceInToOut; DECLARE_INPLACE_OP_INFERER(ScaleOpInplace, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -27,7 +27,7 @@ class SequenceConvOp : public framework::OperatorWithKernel { ...@@ -27,7 +27,7 @@ class SequenceConvOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceConvOp should not be null."); "Input(X) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"), PADDLE_ENFORCE(ctx->HasInput("Filter"),
...@@ -82,7 +82,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ...@@ -82,7 +82,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of output(Out) should not be null."); "Gradient of output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");
...@@ -209,11 +209,13 @@ class SequenceConvGradNoNeedBufferVarsInference ...@@ -209,11 +209,13 @@ class SequenceConvGradNoNeedBufferVarsInference
public: public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference; using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
std::unordered_set<std::string> operator()() const override { const std::unordered_set<std::string> &operator()(
if (!boost::get<bool>(Attrs().at("paddingTrainable"))) { const framework::InferNoNeedBufferVarsContext &ctx) const final {
return {"PaddingData"}; static const std::unordered_set<std::string> kPaddingData({"PaddingData"});
if (!boost::get<bool>(ctx.GetAttr("paddingTrainable"))) {
return kPaddingData;
} else { } else {
return {}; return Empty();
} }
} }
}; };
......
...@@ -221,20 +221,9 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -221,20 +221,9 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
class SoftmaxGradInplaceInferer final : public framework::InplaceOpInference { // NOTE(zjl): AVX implementation of SoftmaxGrad does not support in-place
public: DECLARE_CUDA_ONLY_INPLACE_OP_INFERER(SoftmaxGradInplaceInferer,
using framework::InplaceOpInference::InplaceOpInference; {"Out", framework::GradVarName("X")});
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const final {
if (use_cuda) {
return {{"Out", framework::GradVarName("X")}};
} else {
// NOTE(zjl): AVX implementation of SoftmaxGrad does not support in-place
return {};
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册