提交 80de7e5e 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #3460 from jacquesqiao/public_to_protected

change operator public to protected
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace framework { namespace framework {
template <typename Map, typename T> template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) { static void ForEachVarName(const Map& names, T callback) {
for (auto& name : names) { for (auto& name : names) {
for (auto& n : name.second) { for (auto& n : name.second) {
if (callback(n)) return; if (callback(n)) return;
...@@ -44,7 +44,7 @@ static bool AllInSet( ...@@ -44,7 +44,7 @@ static bool AllInSet(
static std::shared_ptr<OperatorBase> NOP() { static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<operators::NetOp>(); auto net_op = std::make_shared<operators::NetOp>();
net_op->type_ = "@NOP@"; net_op->SetType("@NOP@");
net_op->CompleteAddOp(); net_op->CompleteAddOp();
return net_op; return net_op;
} }
...@@ -70,8 +70,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -70,8 +70,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate, // If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take // just return an NOP. Not return null ptr because NOP does not take
// much time for calculation, but it is useful for simplifying logic. // too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_ /*names*/, kGradVarSuffix /*suffix*/, if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) { no_grad_names /*set*/)) {
return NOP(); return NOP();
} }
...@@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into // Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP. // `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_ /*names*/, kGradVarSuffix /*suffix*/, if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) { no_grad_names /*set*/)) {
ForEachVarName(forwardOp.inputs_, ForEachVarName(forwardOp.Inputs(),
[&no_grad_names](const std::string& name) -> bool { [&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name)); no_grad_names.insert(GradVarName(name));
return false; return false;
...@@ -107,7 +107,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -107,7 +107,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
auto fwd = *it; auto fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd); net->AddOp(bwd);
ForEachVarName(bwd->outputs_, ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) { [&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id); dup_output_ops[out].emplace_back(local_op_id);
return false; return false;
...@@ -154,13 +154,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -154,13 +154,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
} else { } else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp); std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
ForEachVarName(grad_op->inputs_, [&no_grad_names, ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
&net](std::string& grad_input) { grad_op](const std::string& grad_input) {
if (no_grad_names.count(grad_input)) { if (no_grad_names.count(grad_input)) {
// +1 for \0 // +1 for \0
std::string prefix = grad_input.substr( std::string prefix = grad_input.substr(
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
grad_input = prefix + kZeroVarSuffix; grad_op->Rename(grad_input, prefix + kZeroVarSuffix);
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient. // zero variables to that input gradient.
...@@ -170,10 +170,10 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -170,10 +170,10 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
return false; return false;
}); });
ForEachVarName(grad_op->outputs_, ForEachVarName(grad_op->Outputs(),
[&no_grad_names](std::string& grad_output) { [&no_grad_names, &grad_op](const std::string& grad_output) {
if (no_grad_names.count(grad_output)) { if (no_grad_names.count(grad_output)) {
grad_output = kEmptyVarName; grad_op->Rename(grad_output, kEmptyVarName);
} }
return false; return false;
}); });
...@@ -183,7 +183,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -183,7 +183,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
} }
net->AddOp(grad_op); net->AddOp(grad_op);
} }
net->type_ = "@GENERATED_BACKWARD@"; net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp(); net->CompleteAddOp();
return net; return net;
} // namespace framework } // namespace framework
......
...@@ -164,8 +164,8 @@ TEST(Backward, simple_op_grad) { ...@@ -164,8 +164,8 @@ TEST(Backward, simple_op_grad) {
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd); auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ(1UL, gop->Inputs().size());
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("rowwise_add_grad", gop->Type());
ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
} }
...@@ -201,13 +201,13 @@ TEST(Backward, net_fc_backward_normal) { ...@@ -201,13 +201,13 @@ TEST(Backward, net_fc_backward_normal) {
ASSERT_EQ(3UL, net->ops_.size()); ASSERT_EQ(3UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0]; f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
f::OperatorBase &d_add = *net->ops_[1]; f::OperatorBase &d_add = *net->ops_[1];
ASSERT_EQ("rowwise_add_grad", d_add.type_); ASSERT_EQ("rowwise_add_grad", d_add.Type());
f::OperatorBase &d_mul = *net->ops_[2]; f::OperatorBase &d_mul = *net->ops_[2];
ASSERT_EQ("mul_grad", d_mul.type_); ASSERT_EQ("mul_grad", d_mul.Type());
} }
TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_fc_backward_not_have_b) {
...@@ -227,10 +227,10 @@ TEST(Backward, net_fc_backward_not_have_b) { ...@@ -227,10 +227,10 @@ TEST(Backward, net_fc_backward_not_have_b) {
ASSERT_EQ(2UL, net->ops_.size()); ASSERT_EQ(2UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0]; f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
f::OperatorBase &d_mul = *net->ops_[1]; f::OperatorBase &d_mul = *net->ops_[1];
ASSERT_EQ("mul_grad", d_mul.type_); ASSERT_EQ("mul_grad", d_mul.Type());
} }
TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_input_of_network_not_need_grad) {
...@@ -284,7 +284,7 @@ TEST(Backward, net_shared_weight) { ...@@ -284,7 +284,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp()); ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get()); auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size()); ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->type_); ASSERT_EQ("add", bwd_net->ops_[2]->Type());
} }
TEST(Backward, op_register_grad_not_for_network) { TEST(Backward, op_register_grad_not_for_network) {
...@@ -325,15 +325,15 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -325,15 +325,15 @@ TEST(Backward, op_part_of_output_are_not_need) {
ASSERT_EQ(net->ops_.size(), 2UL); ASSERT_EQ(net->ops_.size(), 2UL);
auto &fill_zero = *net->ops_[0]; auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_); ASSERT_EQ("fill_zeros_like", fill_zero.Type());
ASSERT_EQ(1UL, fill_zero.Inputs("Src").size()); ASSERT_EQ(1UL, fill_zero.Inputs("Src").size());
ASSERT_EQ("Z", fill_zero.Input("Src")); ASSERT_EQ("Z", fill_zero.Input("Src"));
ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size()); ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size());
ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst")); ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst"));
auto &d_many_out = *net->ops_[1]; auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_); ASSERT_EQ("many_output_op_grad", d_many_out.Type());
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size()); // I/O/OG
ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
d_many_out.Input(f::GradVarName("z"))); d_many_out.Input(f::GradVarName("z")));
ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y"))); ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
...@@ -345,9 +345,9 @@ TEST(Backward, op_part_of_input_are_not_need) { ...@@ -345,9 +345,9 @@ TEST(Backward, op_part_of_input_are_not_need) {
{{"Out", {"out"}}}, {}); {{"Out", {"out"}}}, {});
auto backward = f::Backward(*fwd, {"a"}); auto backward = f::Backward(*fwd, {"a"});
auto &grad_mul = *backward; auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.Type(), "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL); ASSERT_EQ(grad_mul.Outputs().size(), 2UL);
ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName); ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName);
ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b")); ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b"));
ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out")); ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
...@@ -385,18 +385,18 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -385,18 +385,18 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
auto &grad_fc = *bwd_net->ops_[0]; auto &grad_fc = *bwd_net->ops_[0];
const char *all = paddle::operators::NetOp::kAll; const char *all = paddle::operators::NetOp::kAll;
EXPECT_EQ(grad_fc.inputs_[all].size(), EXPECT_EQ(grad_fc.Inputs(all).size(),
2UL /* external input number */ 2UL /* external input number */
+ 1UL /* external output number*/ + 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/ + 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/); + 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_[all].size(), EXPECT_EQ(grad_fc.Outputs(all).size(),
2UL /* input number of mul*/ 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add + 2UL /* input number of rowwise_add
*/ */
+ 1UL /* input number of sigmod */); + 1UL /* input number of sigmod */);
EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
} }
...@@ -13,7 +13,6 @@ express or implied. See the License for the specific language governing ...@@ -13,7 +13,6 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */ permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
...@@ -23,9 +22,9 @@ enum class OpArgType { IN, OUT }; ...@@ -23,9 +22,9 @@ enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
bool is_grad, OperatorBase::VarNameMap* vars) { bool is_grad, OperatorBase::VarNameMap* vars) {
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_; const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_;
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
...@@ -41,14 +40,14 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, ...@@ -41,14 +40,14 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
auto it = OpRegistry::op_info_map().find(op->type_); auto it = OpRegistry::op_info_map().find(op->Type());
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type_); "'%s' has not been registered.", op->Type());
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->type_); op->Type());
std::string grad_op_type = it->second.grad_op_type_; std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type_); op->Type());
OperatorBase::VarNameMap inputs; OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs; OperatorBase::VarNameMap outputs;
...@@ -60,7 +59,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -60,7 +59,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
it = OpRegistry::op_info_map().find(grad_op_type); it = OpRegistry::op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type); "'%s' has not been registered.", grad_op_type);
return it->second.creator_(grad_op_type, inputs, outputs, op->attrs_); return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
} }
} // namespace framework } // namespace framework
......
...@@ -44,8 +44,8 @@ TEST(GradOpBuilder, AddTwo) { ...@@ -44,8 +44,8 @@ TEST(GradOpBuilder, AddTwo) {
"add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op = std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op); f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(grad_add_op->inputs_.size(), 4UL); EXPECT_EQ(grad_add_op->Inputs().size(), 4UL);
EXPECT_EQ(grad_add_op->outputs_.size(), 2UL); EXPECT_EQ(grad_add_op->Outputs().size(), 2UL);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y"); EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out"); EXPECT_EQ(grad_add_op->Input("Out"), "out");
...@@ -66,7 +66,7 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -66,7 +66,7 @@ TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> grad_test_op = std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"), EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"})); std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
...@@ -80,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -80,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) {
std::vector<std::string>( std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")})); {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"), std::vector<std::string>({f::GradVarName("in2_1"),
...@@ -99,7 +99,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { ...@@ -99,7 +99,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
// 'In2' and 'Out2' are ignored in gradient calculating // 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL); ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In3_mult"), EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"})); std::vector<std::string>({"in3_1", "in3_2"}));
...@@ -111,7 +111,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { ...@@ -111,7 +111,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
f::GradVarName("out2")); f::GradVarName("out2"));
ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>( std::vector<std::string>(
......
...@@ -97,6 +97,8 @@ class OperatorBase { ...@@ -97,6 +97,8 @@ class OperatorBase {
/// rename inputs outputs name /// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name); void Rename(const std::string& old_name, const std::string& new_name);
const VarNameMap& Inputs() const { return inputs_; }
const VarNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
...@@ -110,10 +112,11 @@ class OperatorBase { ...@@ -110,10 +112,11 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
std::string Type() const { return type_; } const std::string& Type() const { return type_; }
void SetType(const std::string& type) { type_ = type; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
public: protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs) // I (Inputs)
......
...@@ -53,15 +53,15 @@ void ExposeOperator(ClassType &m) { ...@@ -53,15 +53,15 @@ void ExposeOperator(ClassType &m) {
.def("run", &ClassType::type::Run) .def("run", &ClassType::type::Run)
.def("type", .def("type",
[](const typename ClassType::type &op) -> std::string { [](const typename ClassType::type &op) -> std::string {
return op.type_; return op.Type();
}) })
.def("outputs", .def("outputs",
[](const typename ClassType::type &op) [](const typename ClassType::type &op)
-> std::map<std::string, std::vector<std::string>> { -> std::map<std::string, std::vector<std::string>> {
return op.outputs_; return op.Outputs();
}) })
.def("inputs", .def("inputs",
[](const typename ClassType::type &op) { return op.inputs_; }) [](const typename ClassType::type &op) { return op.Inputs(); })
.def("__str__", &ClassType::type::DebugString) .def("__str__", &ClassType::type::DebugString)
.def("no_intermediate_outputs", .def("no_intermediate_outputs",
[](const typename ClassType::type &op) { [](const typename ClassType::type &op) {
...@@ -232,7 +232,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -232,7 +232,7 @@ All parameter, weight, gradient are variables in Paddle.
net.def_static("create", net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> { []() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>(); auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net"; retv->SetType("plain_net");
return retv; return retv;
}) })
.def("add_op", &operators::NetOp::AddOp) .def("add_op", &operators::NetOp::AddOp)
......
...@@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) {
std::set<std::string> input_set; std::set<std::string> input_set;
std::set<std::string> output_set; std::set<std::string> output_set;
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->inputs_) { for (auto& ipt : op->Inputs()) {
for (auto& var_name : ipt.second) { for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name); input_set.insert(var_name);
...@@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) {
} }
} }
for (auto& opt : op->outputs_) { for (auto& opt : op->Outputs()) {
for (auto& var_name : opt.second) { for (auto& var_name : opt.second) {
output_set.insert(var_name); output_set.insert(var_name);
} }
......
...@@ -49,8 +49,8 @@ TEST(OpKernel, all) { ...@@ -49,8 +49,8 @@ TEST(OpKernel, all) {
net->CompleteAddOp(); net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
net->inputs_.at(NetOp::kAll)); net->Inputs(NetOp::kAll));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll)); AssertSameVectorWithoutOrder({"y", "z"}, net->Outputs(NetOp::kAll));
auto final_outs = net->OutputVars(false); auto final_outs = net->OutputVars(false);
......
...@@ -82,14 +82,14 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { ...@@ -82,14 +82,14 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope", PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net); arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>(); auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->outputs_.empty(), "net_op has no outputs"); PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) { if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope(); auto& step_scope = scope.NewScope();
// create step net's temp inputs // create step net's temp inputs
for (auto& input : net_op->inputs_) { for (auto& input : net_op->Inputs()) {
// the weight are located in parent scope // the weight are located in parent scope
for (auto& var_name : input.second) { for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) { if (!step_scope.FindVar(var_name)) {
...@@ -98,7 +98,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { ...@@ -98,7 +98,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
} }
} }
// create stepnet's outputs // create stepnet's outputs
for (const auto& output : net_op->outputs_) { for (const auto& output : net_op->Outputs()) {
for (auto& var_name : output.second) { for (auto& var_name : output.second) {
step_scope.NewVar(var_name); step_scope.NewVar(var_name);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册