Commit a36d2416 authored by Yu Yang, committed by fengjiayi

Add no_grad_vars for grad_op_maker (#4770)

* Add no_grad_vars for grad_op_maker

* Add unittest

* Fix unittest

* Fix unittest

* Follow comment
Parent 4cda9a36
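In short, every grad-op maker now receives, alongside the forward op's description, the set of gradient variable names the caller does not need, so makers can blank out or drop those outputs themselves. A schematic sketch of the new calling convention (not code from this commit; it distills the MakeOpGrad change in the hunks below, where OpInfoMap, OpDescBind, and GradOpMaker() are the Paddle types shown there):

// Schematic: how the backward pass invokes a grad-op maker after this commit.
std::vector<std::unique_ptr<OpDescBind>> MakeGrads(
    const OpDescBind& fwd_op,
    const std::unordered_set<std::string>& no_grad_set) {
  auto& info = OpInfoMap::Instance().Get(fwd_op.Type());
  return info.GradOpMaker()(fwd_op, no_grad_set);  // previously: (fwd_op) only
}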
@@ -28,14 +28,15 @@ namespace paddle {
 namespace framework {
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op) {
+    const OperatorBase& op,
+    const std::unordered_set<std::string>& no_grad_set) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -187,7 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(
+        CreateGradOp(forwardOp, no_grad_names));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
@@ -272,7 +274,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
-  // All input gradients of forwarding operator do not need to calculat.
+  // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
   if (AllGradInSet(inputs, no_grad_vars)) {
     return grad_op_descs;  // empty vector
@@ -286,7 +288,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     return grad_op_descs;  // empty vector
   }
-  grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
+  grad_op_descs = OpInfoMap::Instance()
+                      .Get(op_desc->Type())
+                      .GradOpMaker()(*op_desc, no_grad_vars);
 
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
@@ -301,11 +305,6 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
         pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
       }
     }
-    for (const std::string& out_name : desc->OutputArgumentNames()) {
-      if (no_grad_vars.count(out_name)) {
-        desc->Rename(out_name, kEmptyVarName);
-      }
-    }
   }
   for (auto& p : pending_fill_zeros_ops) {
...
@@ -169,6 +169,45 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
+class MinusGradOpDescMaker : public GradOpDescMakerBase {
+ public:
+  using GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
+    std::vector<std::unique_ptr<OpDescBind>> retv;
+    auto x_g = InputGrad("X");
+    if (!x_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", x_g);
+      op_desc->SetAttr("scale", 1.0f);
+      retv.emplace_back(op_desc);
+    }
+
+    auto y_g = InputGrad("Y");
+    if (!y_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", y_g);
+      op_desc->SetAttr("scale", -1.0f);
+      retv.emplace_back(op_desc);
+    }
+    return retv;
+  }
+};
+
+class MinusOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "");
+    AddInput("Y", "");
+    AddOutput("Out", "");
+    AddComment("minus for unittest");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
@@ -187,6 +226,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 REGISTER_OP(mult_in_out, f::NOP, f::MultInOutOpMaker, mult_in_out_grad, f::NOP);
+REGISTER_OPERATOR(minus, f::NOP, f::MinusOpMaker, f::MinusGradOpDescMaker);
 
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
@@ -395,12 +435,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                 2UL       /* external input number */
                 + 1UL     /* external output number*/
                 + 1UL     /* number of gradient of external output*/
-                + 2U /* internal variable number*/);
+                + 2UL /* internal variable number*/
+  );
   EXPECT_EQ(grad_fc.Outputs(all).size(),
             2UL       /* input number of mul*/
-            + 2UL     /* input number of rowwise_add
-                       */
-            + 1UL     /* input number of sigmod */);
+            + 2UL     /* input number of rowwise_add*/
+            + 1UL     /* input number of sigmod */
+            - 1UL     /* out2 is not needed*/);
   EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
@@ -580,8 +621,7 @@ TEST(Backward, intermedia_var_no_grad) {
             std::vector<std::string>({f::GradVarName("out4")}));
   EXPECT_EQ(grad_op4->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("out1")}));
-  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")), std::vector<std::string>());
 }
 
 TEST(Backward, var_no_grad) {
@@ -619,8 +659,7 @@ TEST(Backward, var_no_grad) {
             std::vector<std::string>({f::GradVarName("z2")}));
   EXPECT_EQ(grad_op2->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("y1")}));
-  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), std::vector<std::string>());
 
   f::OpDescBind *fill_zero_op = block->AllOps()[3];
   ASSERT_EQ(fill_zero_op->Type(), "fill_zeros_like");
@@ -719,3 +758,18 @@ TEST(Backward, shared_var) {
   EXPECT_EQ(grad_op1->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b1")}));
 }
+
+TEST(Backward, half_backward) {
+  f::ProgramDesc *program_desc = GetNewProgramDesc();
+  f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
+  f::BlockDescBind *block = program.Block(0);
+  auto *op1 = block->AppendOp();
+  op1->SetType("minus");
+  op1->SetInput("X", {"a"});
+  op1->SetInput("Y", {"b"});
+  op1->SetOutput("Out", {"out"});
+
+  AppendBackward(program, {"b"});
+  auto ops = block->AllOps();
+  ASSERT_EQ(2UL, ops.size());
+}
\ No newline at end of file
@@ -97,8 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
 template <typename T>
 struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
-      T maker(fwd_op);
+    info->grad_op_maker_ = [](
+        const OpDescBind& fwd_op,
+        const std::unordered_set<std::string>& no_grad_set) {
+      T maker(fwd_op, no_grad_set);
       return maker();
     };
   }
...
@@ -13,6 +13,8 @@
    limitations under the License. */
 
 #pragma once
+#include <string>
+#include <unordered_set>
 #include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
@@ -21,27 +23,44 @@ namespace framework {
 class GradOpDescMakerBase {
  public:
-  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
+  explicit GradOpDescMakerBase(
+      const OpDescBind& fwd_op,
+      const std::unordered_set<std::string>& no_grad_set)
+      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
 
   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
 
  protected:
-  static std::vector<std::string> ToGradNames(
-      const std::vector<std::string>& var_names) {
+  std::vector<std::string> InputGrad(const std::string& name,
+                                     bool drop_empty_grad = true) const {
     std::vector<std::string> ret_val;
+    auto var_names = this->Input(name);
     ret_val.reserve(var_names.size());
-    std::transform(var_names.begin(), var_names.end(),
-                   std::back_inserter(ret_val), GradVarName);
+    std::transform(
+        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
+        [this](const std::string& fwd_var_name) -> std::string {
+          auto g_name = GradVarName(fwd_var_name);
+          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
+        });
+    if (!drop_empty_grad) {
+      return ret_val;
+    }
+    std::vector<std::string> dropped_ret_val;
+    dropped_ret_val.reserve(ret_val.size());
+    std::copy_if(ret_val.begin(), ret_val.end(),
+                 std::back_inserter(dropped_ret_val),
+                 [](const std::string& str) { return str != kEmptyVarName; });
+    return dropped_ret_val;
   }
-
-  std::vector<std::string> InputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Input(name));
-  }
 
   std::vector<std::string> OutputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Output(name));
+    std::vector<std::string> ret_val;
+    auto onames = this->Output(name);
+    ret_val.reserve(onames.size());
+    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
+                   GradVarName);
+    return ret_val;
   }
 
   std::vector<std::string> InputNames() const {
@@ -75,6 +94,7 @@ class GradOpDescMakerBase {
  private:
   const OpDescBind& fwd_op_;
+  const std::unordered_set<std::string>& no_grad_set_;
 };
@@ -91,6 +111,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
   virtual std::unique_ptr<OpDescBind> Apply() const = 0;
 };
 
+template <bool DropEmptyIG = true>
 class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
  public:
   using SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -102,7 +123,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
     for (auto& input_param : this->InputNames()) {
       grad->SetInput(input_param, this->Input(input_param));
-      grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
+      grad->SetOutput(GradVarName(input_param),
+                      this->InputGrad(input_param, DropEmptyIG));
     }
 
     for (auto& output_param : this->OutputNames()) {
...
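The behavioral core of the commit is the InputGrad() above: a gradient name found in no_grad_set becomes kEmptyVarName, and drop_empty_grad decides whether such placeholders are dropped or kept. A minimal self-contained sketch of just that logic, outside Paddle (the "@GRAD" suffix and "@EMPTY@" placeholder mirror Paddle's GradVarName/kEmptyVarName conventions, which is an assumption here):

// Standalone demo of the drop/keep semantics added by this commit.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

const char kEmptyVarName[] = "@EMPTY@";

std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

std::vector<std::string> InputGrad(
    const std::vector<std::string>& var_names,
    const std::unordered_set<std::string>& no_grad_set,
    bool drop_empty_grad = true) {
  std::vector<std::string> ret_val;
  ret_val.reserve(var_names.size());
  // Map each forward var to its grad name, blanking out unwanted grads.
  std::transform(var_names.begin(), var_names.end(),
                 std::back_inserter(ret_val),
                 [&](const std::string& fwd_var_name) -> std::string {
                   auto g_name = GradVarName(fwd_var_name);
                   return no_grad_set.count(g_name) == 0 ? g_name
                                                         : kEmptyVarName;
                 });
  if (!drop_empty_grad) return ret_val;  // keep placeholders
  std::vector<std::string> dropped;
  dropped.reserve(ret_val.size());
  std::copy_if(ret_val.begin(), ret_val.end(), std::back_inserter(dropped),
               [](const std::string& s) { return s != kEmptyVarName; });
  return dropped;
}

int main() {
  std::unordered_set<std::string> no_grad_set{"b@GRAD"};
  for (auto& n : InputGrad({"a", "b"}, no_grad_set))
    std::cout << n << "\n";  // prints: a@GRAD            (b's slot dropped)
  for (auto& n : InputGrad({"a", "b"}, no_grad_set, false))
    std::cout << n << "\n";  // prints: a@GRAD, @EMPTY@   (placeholder kept)
}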
@@ -59,11 +59,5 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
                                   op_desc.GetAttrMap());
 }
 
-std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
-    OpDescBind* op_desc) {
-  auto& info = OpInfoMap::Instance().Get(op_desc->Type());
-  return info.grad_op_maker_(*op_desc);
-}
-
 }  // namespace framework
 }  // namespace paddle
@@ -79,9 +79,6 @@ class OpRegistry {
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
 
-  static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
-      OpDescBind* op_desc);
-
   static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
@@ -164,8 +161,9 @@ class OpKernelRegistrar : public Registrar {
                              grad_op_class)                                   \
   REGISTER_OPERATOR(grad_op_type, grad_op_class);                             \
   class _GradOpDescMaker_##grad_op_type##_                                    \
-      : public ::paddle::framework::DefaultGradOpDescMaker {                  \
-    using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
+      : public ::paddle::framework::DefaultGradOpDescMaker<true> {            \
+    using ::paddle::framework::DefaultGradOpDescMaker<                        \
+        true>::DefaultGradOpDescMaker;                                        \
                                                                               \
    protected:                                                                 \
     virtual std::string GradOpType() const { return #grad_op_type; }          \
...
@@ -36,8 +36,8 @@
 using OpCreator = std::function<OperatorBase*(
     const std::string& /*type*/, const VariableNameMap& /*inputs*/,
    const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
 
-using GradOpMakerFN =
-    std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
+using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
+    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>;
 
 }  // namespace framework
 }  // namespace paddle
@@ -115,8 +115,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
-            ops::MultiplexGradOp);
+REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<false>);
+REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
 REGISTER_OP_CPU_KERNEL(
     multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
...
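Why DefaultGradOpDescMaker<false> here: multiplex takes a duplicable input X, and its grad op presumably needs one output slot per element of X, so empty gradient slots must be preserved rather than dropped. A hypothetical illustration (names assumed, not from the commit), assuming X = {"x0", "x1", "x2"} and "x1@GRAD" is in no_grad_set:

// With DropEmptyIG = true (the macro default), the grad op's X@GRAD list
// would be {"x0@GRAD", "x2@GRAD"}: x1's slot silently disappears and
// positions shift. With DropEmptyIG = false it stays positional:
// {"x0@GRAD", kEmptyVarName, "x2@GRAD"}, so gradient slot i still
// corresponds to input i.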