Unverified commit 878a40f5, authored by Zeng Jinle, committed by GitHub

Support NoNeedBufferVarsInference in dygraph backward (#20868)

* support no need buffer vars in dygraph, test=develop

* fix inference compilation error, test=develop

* update no_need_buffer_vars_inference, test=develop

* add unittests for no_need_buffer_vars_context, test=develop

* refine no_need_buffer_vars by return ref, test=develop

* polish some codes, test=develop
Parent: bf379fef
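In short: an op can declare gradient inputs whose data buffers are never read (only their dims are), and this patch makes the dygraph tracer honor that declaration so those buffers can be released right after the forward pass. A minimal sketch of the declaration side, using a hypothetical op `my_op_grad` and slot name `X` (the macro and the registration plumbing are the pieces touched below):

```cpp
// Hypothetical grad op that reads only the dims of "X", never its data,
// so "X" qualifies as a no-need-buffer input.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(MyOpGradNoNeedBufferVarsInference, "X");

// Registered together with the op; OpInfoFiller<T, kNoNeedBufferVarsInference>
// (changed below) wires an instance into OpInfo.
REGISTER_OPERATOR(my_op_grad, ops::MyOpGradOp,
                  MyOpGradNoNeedBufferVarsInference);
```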
@@ -191,13 +191,6 @@ copy(fluid_lib_dist
${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet)
set(module "imperative")
copy(fluid_lib_dist
SRCS ${src_dir}/${module}/type_defs.h ${src_dir}/${module}/dygraph_grad_maker.h ${src_dir}/${module}/layer.h ${src_dir}/${module}/flags.h
DSTS ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/
)
set(module "operators")
copy(fluid_lib_dist
SRCS ${src_dir}/${module}/reader/blocking_queue.h
@@ -224,6 +217,12 @@ copy(fluid_lib_dist
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat
)
set(module "imperative")
copy(fluid_lib_dist
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/jit/*.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/jit
)
set(module "pybind")
copy(fluid_lib_dist
SRCS ${CMAKE_CURRENT_BINARY_DIR}/paddle/fluid/${module}/pybind.h
......
@@ -118,9 +118,12 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc DEPS attribute)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto no_need_buffer_vars_inference)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_test(no_need_buffer_vars_inference_test SRCS no_need_buffer_vars_inference_test.cc DEPS no_need_buffer_vars_inference layer)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
......
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
@@ -231,9 +232,9 @@ struct OpInfoFiller<T, kShapeInference> {
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_inplace_ = [](const OpDesc& op_desc, bool use_cuda) {
info->infer_inplace_ = [](bool use_cuda) {
T infer;
return infer(op_desc, use_cuda);
return infer(use_cuda);
};
}
};
@@ -241,12 +242,7 @@ struct OpInfoFiller<T, kInplaceOpInference> {
template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
T infer(inputs, outputs, attrs);
return infer();
};
info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
}
};
......
@@ -32,32 +32,28 @@ class InplaceOpInference {
public:
virtual ~InplaceOpInference() {}
virtual std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const = 0;
bool use_cuda) const = 0;
};
/*
Inplace In and Out for operator only have an Input and an Output.
For example, activation op.
*/
class SingleOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
auto inputs = op_desc.InputNames();
auto outputs = op_desc.OutputNames();
PADDLE_ENFORCE_EQ(inputs.size(), 1, "Op inputs must be unique");
PADDLE_ENFORCE_EQ(outputs.size(), 1, "Op outputs must be unique");
return {{inputs[0], outputs[0]}};
#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \
std::unordered_map<std::string, std::string> operator()( \
bool use_cuda) const final { \
return {__VA_ARGS__}; \
} \
}
};
#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \
#define DECLARE_CUDA_ONLY_INPLACE_OP_INFERER(class_name, ...) \
class class_name final : public ::paddle::framework::InplaceOpInference { \
public: \
std::unordered_map<std::string, std::string> operator()( \
const ::paddle::framework::OpDesc& op_desc, \
bool use_cuda) const final { \
if (use_cuda) { \
return {__VA_ARGS__}; \
} else { \
return {}; \
} \
} \
}
......
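For reference, these two macros are exercised later in this same commit; the activation and softmax changes further down reduce to exactly these one-liners:

```cpp
// Unconditional in-place reuse: the output shares the input's buffer.
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});

// In-place only when running on CUDA; on CPU the returned map is empty
// (the AVX SoftmaxGrad kernel does not support in-place).
DECLARE_CUDA_ONLY_INPLACE_OP_INFERER(SoftmaxGradInplaceInferer,
                                     {"Out", framework::GradVarName("X")});
```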
@@ -81,7 +81,7 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
auto *op_desc = op->Node()->Op();
auto in_to_outs =
OpInfoMap::Instance().Get(op_type).infer_inplace_(*op_desc, use_cuda);
OpInfoMap::Instance().Get(op_type).infer_inplace_(use_cuda);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &in_args = op_desc->Input(in_param);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include <string>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
const Attribute &InferNoNeedBufferVarsContext::GetAttr(
const std::string &name) const {
auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s",
name);
return iter->second;
}
StaticGraphInferNoNeedBufferVarsContext::
StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
const std::string &slot) const {
auto iter = outputs_.find(slot);
if (iter != outputs_.end()) {
for (auto &var : iter->second) {
if (var != kEmptyVarName) return true;
}
}
return false;
}
DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool DyGraphInferNoNeedBufferVarsContext::HasOutput(
const std::string &slot) const {
auto iter = outputs_.find(slot);
if (iter != outputs_.end()) {
for (auto &var : iter->second) {
if (var) return true;
}
}
return false;
}
} // namespace framework
} // namespace paddle
@@ -14,35 +14,70 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class NoNeedBufferVarsInference {
class InferNoNeedBufferVarsContext {
public:
NoNeedBufferVarsInference(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs)
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
explicit InferNoNeedBufferVarsContext(const framework::AttributeMap &attrs)
: attrs_(attrs) {}
virtual ~InferNoNeedBufferVarsContext() = default;
virtual ~NoNeedBufferVarsInference() = default;
virtual bool HasOutput(const std::string &slot) const = 0;
const VariableNameMap &Inputs() const { return inputs_; }
const Attribute &GetAttr(const std::string &attr) const;
const VariableNameMap &Outputs() const { return outputs_; }
private:
const framework::AttributeMap &attrs_;
};
const AttributeMap &Attrs() const { return attrs_; }
class StaticGraphInferNoNeedBufferVarsContext final
: public InferNoNeedBufferVarsContext {
public:
StaticGraphInferNoNeedBufferVarsContext(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs);
virtual std::unordered_set<std::string> operator()() const = 0;
bool HasOutput(const std::string &slot) const final;
private:
const VariableNameMap &inputs_;
const VariableNameMap &outputs_;
const AttributeMap &attrs_;
};
class DyGraphInferNoNeedBufferVarsContext final
: public InferNoNeedBufferVarsContext {
public:
DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const AttributeMap &attr);
bool HasOutput(const std::string &slot) const final;
private:
const imperative::NameVarBaseMap &inputs_;
const imperative::NameVarBaseMap &outputs_;
};
class NoNeedBufferVarsInference {
public:
virtual ~NoNeedBufferVarsInference() = default;
virtual const std::unordered_set<std::string> &operator()(
const InferNoNeedBufferVarsContext &ctx) const = 0;
protected:
static const std::unordered_set<std::string> &Empty() {
static std::unordered_set<std::string> empty;
return empty;
}
};
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
@@ -52,10 +87,46 @@ class NoNeedBufferVarsInference {
using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \
\
std::unordered_set<std::string> operator()() const final { \
return {__VA_ARGS__}; \
const std::unordered_set<std::string> &operator()( \
const ::paddle::framework::InferNoNeedBufferVarsContext &ctx) \
const final { \
static std::unordered_set<std::string> __ret__{__VA_ARGS__}; \
return __ret__; \
} \
}
class InferNoNeedBufferVarsFN {
public:
inline const std::unordered_set<std::string> &operator()(
const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx);
}
inline const std::unordered_set<std::string> &operator()(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx);
}
inline operator bool() const { return inferer_ != nullptr; }
inline bool operator!() const { return inferer_ == nullptr; }
inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) {
PADDLE_ENFORCE_NOT_NULL(inferer);
PADDLE_ENFORCE_EQ(inferer_, nullptr);
inferer_ = inferer;
}
private:
std::shared_ptr<NoNeedBufferVarsInference> inferer_;
};
} // namespace framework
} // namespace paddle
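A sketch of the caller side (the call site below is hypothetical; the types and overloads are the ones declared above). The same registered inferer serves both modes, and overload resolution picks the matching context type:

```cpp
namespace fw = paddle::framework;

// Hypothetical call site; assumes "my_op_grad" registered an inferer.
void Example() {
  const fw::OpInfo& info = fw::OpInfoMap::Instance().Get("my_op_grad");
  const fw::InferNoNeedBufferVarsFN& inferer = info.NoNeedBufferVarsInferer();
  if (inferer) {  // some ops never register one
    fw::VariableNameMap ins{{"X", {"x0"}}}, outs{{"Out", {"out0"}}};
    fw::AttributeMap attrs;
    // Static-graph overload: builds a StaticGraphInferNoNeedBufferVarsContext
    // internally. A NameVarBaseMap overload covers dygraph callers likewise.
    const std::unordered_set<std::string>& slots = inferer(ins, outs, attrs);
    (void)slots;
  }
}
```

Returning a const reference to a function-local static set (the `__ret__` in the macro above) is what the commit bullet "refine no_need_buffer_vars by return ref" delivers: hot paths such as `PrepareData` no longer copy an `unordered_set` on every call.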
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/layer.h"
namespace paddle {
namespace framework {
TEST(test_no_need_buffer_vars_inference, test_static_graph) {
AttributeMap attrs{{"is_test", true}};
VariableNameMap inputs;
VariableNameMap outputs{{"Out", {kEmptyVarName, "tmp_0"}}};
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_FALSE(ctx.HasOutput("X"));
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
}
TEST(test_no_need_buffer_vars_inference, test_dygraph) {
AttributeMap attrs{{"is_test", true}};
imperative::NameVarBaseMap inputs;
imperative::NameVarBaseMap outputs;
outputs["Out"].emplace_back(nullptr);
outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_FALSE(ctx.HasOutput("X"));
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
}
} // namespace framework
} // namespace paddle
@@ -233,15 +233,17 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
std::unordered_set<std::string> no_need_buffer_vars;
const std::unordered_set<std::string>* no_need_buffer_vars = nullptr;
if (info_ && info_->NoNeedBufferVarsInferer()) {
no_need_buffer_vars =
Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs());
&(Info().NoNeedBufferVarsInferer()(Inputs(), Outputs(), Attrs()));
if (no_need_buffer_vars->empty()) no_need_buffer_vars = nullptr;
}
for (auto it = inputs_.begin(); it != inputs_.end();) {
auto& input = *it;
bool is_no_need_buffer_var = (no_need_buffer_vars.count(input.first) > 0);
bool is_no_need_buffer_var =
(no_need_buffer_vars && no_need_buffer_vars->count(input.first) > 0);
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
auto var_name = input.second[i];
@@ -1027,21 +1029,18 @@ Scope* OperatorWithKernel::PrepareData(
RuntimeContext* ctx) const {
Scope* new_scope = nullptr;
std::unordered_set<std::string> no_buffer_ins;
const std::unordered_set<std::string>* no_buffer_ins = nullptr;
if (info_) {
auto& no_buffer_inferer = info_->NoNeedBufferVarsInferer();
// Some op may not register NoNeedBufferVarsInferer
if (no_buffer_inferer) {
no_buffer_ins = no_buffer_inferer(Inputs(), Outputs(), Attrs());
no_buffer_ins = &(no_buffer_inferer(Inputs(), Outputs(), Attrs()));
if (no_buffer_ins->empty()) no_buffer_ins = nullptr;
}
}
for (auto& var_name_item : Inputs()) {
// NOTE(zjl): STL does not guarantee fast std::unordered_set::count when set
// is empty. At least STL implemented on my mac does calculate hash code
// of search key even though the set is empty.
if (!no_buffer_ins.empty() &&
no_buffer_ins.count(var_name_item.first) > 0) {
if (no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0) {
VLOG(7) << "Skip scanning input " << var_name_item.first
<< " in Operator " << type_;
continue;
......
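The NOTE above explains the motivation for this hunk's pointer-based rewrite: rather than calling `count()` on a possibly empty set (which still hashes the key), the code keeps a null pointer whenever there is nothing to skip. A standalone, runnable illustration of the pattern in plain STL (not Paddle code):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  std::unordered_set<std::string> no_buffer_slots{"X"};
  // Null pointer instead of an empty set: the per-input check degrades to a
  // cheap null test instead of hashing into an empty container.
  const std::unordered_set<std::string>* skip =
      no_buffer_slots.empty() ? nullptr : &no_buffer_slots;

  std::unordered_map<std::string, std::vector<std::string>> inputs{
      {"X", {"x0"}}, {"Y", {"y0"}}};
  for (const auto& item : inputs) {
    if (skip && skip->count(item.first) > 0) {
      std::cout << "skip scanning input " << item.first << "\n";
      continue;
    }
    std::cout << "scan input " << item.first << "\n";
  }
  return 0;
}
```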
@@ -31,7 +31,7 @@ class InferShapeContext;
class InferVarTypeContext;
class BlockDesc;
class Variable;
class NoNeedBufferVarsInference;
class InferNoNeedBufferVarsFN;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector.
@@ -67,11 +67,7 @@ using InferVarTypeFN =
using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&, bool)>;
using InferNoNeedBufferVarsFN = std::function<std::unordered_set<std::string>(
const VariableNameMap& /*inputs*/, const VariableNameMap& /*outputs*/,
const AttributeMap& /*attrs*/)>;
using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;
} // namespace framework
} // namespace paddle
@@ -20,6 +20,36 @@
namespace paddle {
namespace imperative {
static void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return;
auto* ins = op->GetMutableInsMap();
const auto& no_need_buffer_slots =
inferer(*ins, op->GetOutsMap(), op->Attrs());
if (no_need_buffer_slots.empty()) return;
for (auto& slot : no_need_buffer_slots) {
auto iter = ins->find(slot);
if (iter == ins->end()) continue;
VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type();
for (auto& each_var : iter->second) {
if (!each_var) continue;
auto& var = each_var->Var();
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor");
// TODO(zjl): support higher order derivatives
auto new_var = new VarBase(false, each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
each_var.reset(new_var);
}
}
}
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
const OpBase* fw_op_base, const NameVarBaseMap& in,
const NameVarBaseMap& out) {
@@ -151,6 +181,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
// this OpBase* is just used to manage op's life time
engine_->InsertOp(grad_op.get(), grad_op);
ClearNoNeedBufferInputs(grad_op.get());
}
}
......
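Concretely, `ClearNoNeedBufferInputs` swaps each declared input's `VarBase` for a fresh one that carries only the tensor's dims, so backward shape inference keeps working while this op's reference to the data allocation is dropped (`each_var.reset(new_var)`). A standalone sketch of the idea with simplified stand-in types (not Paddle code):

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for a traced variable: shape metadata plus an optional buffer.
struct Tensor {
  std::vector<int64_t> dims;
  std::shared_ptr<std::vector<float>> data;  // null => no buffer held
};

int main() {
  auto x = std::make_shared<Tensor>();
  x->dims = {32, 128};
  x->data = std::make_shared<std::vector<float>>(32 * 128);

  // What the pass does for a no-need-buffer slot: keep a dims-only
  // placeholder and drop this op's reference to the real buffer, so the
  // allocation can be freed once nothing else holds it.
  auto placeholder = std::make_shared<Tensor>();
  placeholder->dims = x->dims;
  x = placeholder;

  std::cout << "dims kept: " << x->dims[0] << "x" << x->dims[1]
            << ", buffer held: " << (x->data ? "yes" : "no") << "\n";
  return 0;
}
```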
@@ -93,6 +93,7 @@ if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
......
@@ -888,6 +888,7 @@ class PowOpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
};
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
@@ -903,8 +904,7 @@ namespace plat = paddle::platform;
ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(), \
paddle::imperative::OpBase>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
ops::ActFwdInplaceInferer, void>::type); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference);
@@ -932,7 +932,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut);
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
@@ -962,7 +962,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut);
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
@@ -991,7 +991,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut);
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
@@ -1019,7 +1019,7 @@ REGISTER_OPERATOR(
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut);
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
@@ -1049,7 +1049,7 @@ REGISTER_OPERATOR(
ops::PowGradOpMaker<paddle::framework::OpDesc>,
ops::PowGradOpMaker<paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
::paddle::framework::SingleOpInplaceInToOut, void>::type);
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInference);
......
@@ -298,24 +298,14 @@ class AffineChannelNoNeedBufferVarsInference
public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private:
inline bool HasOutput(const std::string& name) const {
auto& outputs = Outputs();
auto iter = outputs.find(name);
if (iter == outputs.end() || iter->second.empty()) {
return false;
const std::unordered_set<std::string>& operator()(
const framework::InferNoNeedBufferVarsContext& ctx) const final {
static const std::unordered_set<std::string> kX({"X"});
if (!ctx.HasOutput(framework::GradVarName("Scale")) &&
!ctx.HasOutput(framework::GradVarName("Bias"))) {
return kX;
} else {
return iter->second[0] != framework::kEmptyVarName;
}
}
public:
std::unordered_set<std::string> operator()() const override {
if (!HasOutput(framework::GradVarName("Scale")) &&
!HasOutput(framework::GradVarName("Bias"))) {
return {"X"};
} else {
return {};
return Empty();
}
}
};
......
@@ -98,7 +98,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
using ScaleOpInplace = framework::SingleOpInplaceInToOut;
DECLARE_INPLACE_OP_INFERER(ScaleOpInplace, {"X", "Out"});
} // namespace operators
} // namespace paddle
......
@@ -27,7 +27,7 @@ class SequenceConvOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
@@ -82,7 +82,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");
@@ -209,11 +209,13 @@ class SequenceConvGradNoNeedBufferVarsInference
public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
std::unordered_set<std::string> operator()() const override {
if (!boost::get<bool>(Attrs().at("paddingTrainable"))) {
return {"PaddingData"};
const std::unordered_set<std::string> &operator()(
const framework::InferNoNeedBufferVarsContext &ctx) const final {
static const std::unordered_set<std::string> kPaddingData({"PaddingData"});
if (!boost::get<bool>(ctx.GetAttr("paddingTrainable"))) {
return kPaddingData;
} else {
return {};
return Empty();
}
}
};
......
@@ -221,20 +221,9 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
class SoftmaxGradInplaceInferer final : public framework::InplaceOpInference {
public:
using framework::InplaceOpInference::InplaceOpInference;
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const final {
if (use_cuda) {
return {{"Out", framework::GradVarName("X")}};
} else {
// NOTE(zjl): AVX implementation of SoftmaxGrad does not support in-place
return {};
}
}
};
// NOTE(zjl): AVX implementation of SoftmaxGrad does not support in-place
DECLARE_CUDA_ONLY_INPLACE_OP_INFERER(SoftmaxGradInplaceInferer,
{"Out", framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......