From 4a604c2651ea34b5befa9ac45028ddbae7733ad0 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 14 Aug 2017 12:54:53 +0800
Subject: [PATCH] Polish our code based on YuYang's review

---
 paddle/framework/backward_test.cc          | 26 +++++----
 paddle/framework/ddim.cc                   |  7 ---
 paddle/framework/ddim.h                    |  2 -
 paddle/framework/grad_op_builder.cc        |  3 -
 paddle/framework/grad_op_builder_test.cc   | 12 ++--
 paddle/framework/op_registry.h             | 33 +++++------
 paddle/framework/op_registry_test.cc       | 53 ++++++++---------
 paddle/framework/operator.cc               | 57 ++++++++++++++-----
 paddle/framework/operator.h                | 37 ++----------
 paddle/framework/operator_test.cc          | 45 ++++++++-------
 paddle/operators/mean_op.cc                |  2 +-
 paddle/operators/recurrent_op.cc           |  6 +-
 paddle/operators/recurrent_op_test.cc      |  2 -
 .../v2/framework/tests/test_add_two_op.py  |  8 ---
 14 files changed, 138 insertions(+), 155 deletions(-)

diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index dc09f095b9..d6ba1f7d63 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
   RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input X of Add").IgnoreGradient();
-    AddInput("b", "Bias of Add").IgnoreGradient();
-    AddOutput("Out", "Out of Add").IgnoreGradient();
+    AddInput("X", "Input X of Add").NoGradient();
+    AddInput("b", "Bias of Add").NoGradient();
+    AddOutput("Out", "Out of Add").NoGradient();
     AddComment("Add Op");
   }
 };
@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
     AddInput("X", "x");
     AddInput("W", "w");
     AddInput("b", "b");
-    AddOutput("mul_result", "").SetTemporary();
-    AddOutput("add_result", "").SetTemporary();
+    AddOutput("mul_result", "").SetIntermediate();
+    AddOutput("add_result", "").SetIntermediate();
     AddOutput("Out", "");
     AddComment("");
   }
@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
  public:
   AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "x").SetMultiple();
+    AddInput("X", "x").SetDuplicable();
     AddOutput("Y", "y");
     AddComment("");
   }
@@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   auto bwd_net = static_cast<paddle::operators::NetOp *>(backward.get());
   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
   auto &grad_fc = *bwd_net->ops_[0];
-  EXPECT_EQ(grad_fc.inputs_["all"].size(),
+
+  const char *all = paddle::operators::NetOp::kAll;
+  EXPECT_EQ(grad_fc.inputs_[all].size(),
             2UL       /* external input number */
                 + 1UL /* external output number*/
                 + 1UL /* number of gradient of external output*/
                 + 2U  /* internal variable number*/);
-  EXPECT_EQ(grad_fc.outputs_["all"].size(),
+  EXPECT_EQ(grad_fc.outputs_[all].size(),
             2UL       /* input number of mul*/
                 + 2UL /* input number of rowwise_add */
                 + 1UL /* input number of sigmod */);
-  EXPECT_EQ(bwd_net->ops_[1]->inputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[1]->outputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->inputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->outputs_["all"].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL);
 }

diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 0b76a4fdb7..cfd3e8dfde 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
 DDim::DDim(std::initializer_list<int> init_list) {
   *this = make_ddim(init_list);
 }
-
-std::string DDim::DebugString() const {
-  std::ostringstream ss;
-  ss << *this;
-  return ss.str();
-}
-
 }  // namespace framework
 }  // namespace paddle

diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index 1627bcb269..95f294b627 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -72,8 +72,6 @@ struct DDim {
   DDim operator*(DDim d) const;
 
   ssize_t size() const;
-
-  std::string DebugString() const;
 };
 
 /**

diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index 35db0cf716..7319fcc88c 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -18,9 +18,6 @@ permissions and limitations under the License. */
 
 namespace paddle {
 namespace framework {
-
-class OpRegistry;
-
 enum class OpArgType { IN, OUT };
 
 static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,

diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index c95583c0af..210e07942b 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
   MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("In1", "a single input");
-    AddInput("In2_mult", "a multiple input").SetMultiple();
+    AddInput("In2_mult", "a multiple input").SetDuplicable();
     AddInput("In3", "another single input");
     AddOutput("Out1", "a single output");
-    AddOutput("Out2_mult", "a multiple output").SetMultiple();
+    AddOutput("Out2_mult", "a multiple output").SetDuplicable();
     AddComment("test op with multiple inputs and outputs");
   }
 };
@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
   IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("In1", "a single input");
-    AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient();
-    AddInput("In3_mult", "another multiple input").SetMultiple();
-    AddOutput("Out1_mult", "a multiple output").SetMultiple();
-    AddOutput("Out2", "a single output").IgnoreGradient();
+    AddInput("In2_mult", "a multiple input").SetDuplicable().NoGradient();
+    AddInput("In3_mult", "another multiple input").SetDuplicable();
+    AddOutput("Out1_mult", "a multiple output").SetDuplicable();
+    AddOutput("Out2", "a single output").NoGradient();
     AddComment("op with inputs and outputs ignored in gradient calculating");
   }
 };

diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index f6b71a4efd..d840c1c4e0 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -47,17 +47,17 @@ class OpProtoAndCheckerMaker {
 struct VariableBuilder {
   OpProto::Var* var_;
 
-  VariableBuilder& SetMultiple() {
+  VariableBuilder& SetDuplicable() {
     var_->set_duplicable(true);
     return *this;
   }
 
-  VariableBuilder& SetTemporary() {
+  VariableBuilder& SetIntermediate() {
     var_->set_intermediate(true);
     return *this;
   }
 
-  VariableBuilder& IgnoreGradient() {
+  VariableBuilder& NoGradient() {
     var_->set_no_gradient(true);
     return *this;
   }
@@ -118,7 +118,7 @@ class OpProtoAndCheckerMaker {
 class OpRegistry {
   using OpCreator = std::function<OperatorBase* ()>;
-  using VarNameMap = std::map<std::string, std::vector<std::string>>;
+  using VarNameMap = OperatorBase::VarNameMap;
 
  public:
   template <typename OpType, typename ProtoMakerType>
@@ -164,25 +164,22 @@ class OpRegistry {
     return std::shared_ptr<OperatorBase>(op);
   }
 
-  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
-    VarNameMap inputs;
-    for (auto& input : op_desc.inputs()) {
-      auto& var_names = inputs[input.parameter()];
-      auto& var_names_in_proto = input.arguments();
-      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
-      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
-                std::back_inserter(var_names));
-    }
-
-    VarNameMap outputs;
-    for (auto& output : op_desc.outputs()) {
-      auto& var_names = outputs[output.parameter()];
-      auto& var_names_in_proto = output.arguments();
+  static VarNameMap ConvertOpDescVarsToVarNameMap(
+      const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
+    VarNameMap ret_val;
+    for (auto& var : op_desc_vars) {
+      auto& var_names = ret_val[var.parameter()];
+      auto& var_names_in_proto = var.arguments();
       var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
       std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
                 std::back_inserter(var_names));
     }
+    return ret_val;
+  }
 
+  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
+    VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
+    VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
     AttributeMap attrs;
     for (auto& attr : op_desc.attrs()) {
       attrs[attr.name()] = GetAttrValue(attr);

diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 456a967629..ec7430a95f 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
   MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op").SetMultiple();
-    AddOutput("output", "output of cosine op").SetTemporary();
+    AddInput("input", "input of cosine op").SetDuplicable();
+    AddOutput("output", "output of cosine op").SetIntermediate();
     auto my_checker = [](int i) {
       PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
     };
@@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 }  // namespace framework
 }  // namespace paddle
 
+static void ConstructVars(const std::string& param_name,
+                          std::initializer_list<const char*> arguments,
+                          paddle::framework::OpDesc::Var* var) {
+  var->set_parameter(param_name);
+  for (auto& arg_name : arguments) {
+    *var->mutable_arguments()->Add() = arg_name;
+  }
+}
+
 REGISTER_OP(cos_sim, paddle::framework::CosineOp,
             paddle::framework::CosineOpProtoAndCheckerMaker);
 REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
@@ -59,13 +68,11 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
 TEST(OpRegistry, CreateOp) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
+  auto* input = op_desc.add_inputs();
+  ConstructVars("input", {"aa"}, input);
 
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  auto* output = op_desc.add_outputs();
+  ConstructVars("output", {"bb"}, output);
 
   float scale = 3.3;
   auto attr = op_desc.mutable_attrs()->Add();
@@ -85,13 +92,11 @@ TEST(OpRegistry, CreateOp) {
 TEST(OpRegistry, IllegalAttr) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
+  auto* input = op_desc.add_inputs();
+  ConstructVars("input", {"aa"}, input);
 
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  auto* output = op_desc.add_outputs();
+  ConstructVars("output", {"bb"}, output);
 
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
@@ -115,13 +120,11 @@ TEST(OpRegistry, IllegalAttr) {
 TEST(OpRegistry, DefaultValue) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
+  auto* input = op_desc.add_inputs();
+  ConstructVars("input", {"aa"}, input);
 
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  auto* output = op_desc.add_outputs();
+  ConstructVars("output", {"bb"}, output);
 
   ASSERT_TRUE(op_desc.IsInitialized());
 
@@ -136,13 +139,11 @@ TEST(OpRegistry, DefaultValue) {
 TEST(OpRegistry, CustomChecker) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("my_test_op");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "ii";
+  auto* input = op_desc.add_inputs();
+  ConstructVars("input", {"ii"}, input);
 
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "oo";
+  auto* output = op_desc.add_outputs();
+  ConstructVars("output", {"oo"}, output);
 
   // attr 'test_attr' is not set
   bool caught = false;

diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index b54d0b40ce..351a544c0b 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
 }
 
 const std::string& OperatorBase::Input(const std::string& name) const {
-  auto it = inputs_.find(name);
-  PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
-                 name);
-  PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
+  auto& ins = Inputs(name);
+  PADDLE_ENFORCE_EQ(ins.size(), 1UL,
                     "Op %s input %s should contain only one variable", type_,
                     name);
-  return it->second[0];
+  return ins[0];
 }
 
 const std::vector<std::string>& OperatorBase::Inputs(
     const std::string& name) const {
-  return inputs_.at(name);
+  auto it = inputs_.find(name);
+  PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
+                 name);
+  return it->second;
 }
 
 const std::string& OperatorBase::Output(const std::string& name) const {
-  auto it = outputs_.find(name);
-  PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
-                 name);
-  PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
-                    "Op %s input %s should contain only one variable", type_,
+  auto& outs = Outputs(name);
+  PADDLE_ENFORCE_EQ(outs.size(), 1UL,
+                    "Op %s output %s should contain only one variable", type_,
                     name);
-  return it->second[0];
+  return outs[0];
 }
 
 const std::vector<std::string>& OperatorBase::Outputs(
     const std::string& name) const {
-  return outputs_.at(name);
+  auto it = outputs_.find(name);
+  PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
+                 name);
+  return it->second;
 }
 
 std::string OperatorBase::DebugString() const {
@@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
   }
 }
 
+std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
+  std::vector<std::string> ret_val;
+  if (has_intermediate) {
+    // push all outputs into ret_val
+    for (auto& o : outputs_) {
+      ret_val.reserve(ret_val.size() + o.second.size());
+      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+    }
+    return ret_val;
+  }
+  auto it = OpProtos().find(type_);
+  PADDLE_ENFORCE(
+      it != OpProtos().end(),
+      "Operator %s not registered, cannot figure out intermediate outputs",
+      type_);
+
+  // get all OpProto::Var for outputs
+  for (auto& o : it->second.outputs()) {
+    // ignore all intermediate output
+    if (o.intermediate()) continue;
+    auto out = outputs_.find(o.name());
+    if (out != outputs_.end()) {
+      ret_val.reserve(ret_val.size() + out->second.size());
+      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
+    }
+  }
+  return ret_val;
+}
+
 }  // namespace framework
 }  // namespace paddle

diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index b5a409a23e..e145649d30 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -116,34 +116,7 @@ class OperatorBase {
   //! TODO add a vector_view to prevent memory copy.
   const std::vector<std::string>& Outputs(const std::string& name) const;
 
-  virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
-    std::vector<std::string> ret_val;
-    if (has_intermediate) {
-      // push all outputs into ret_val
-      for (auto& o : outputs_) {
-        ret_val.reserve(ret_val.size() + o.second.size());
-        ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
-      }
-      return ret_val;
-    }
-    auto it = OpProtos().find(type_);
-    PADDLE_ENFORCE(
-        it != OpProtos().end(),
-        "Operator %s not registered, cannot figure out intermediate outputs",
-        type_);
-
-    // get all OpProto::Var for outputs
-    for (auto& o : it->second.outputs()) {
-      // ignore all intermediate output
-      if (o.intermediate()) continue;
-      auto out = outputs_.find(o.name());
-      if (out != outputs_.end()) {
-        ret_val.reserve(ret_val.size() + out->second.size());
-        ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
-      }
-    }
-    return ret_val;
-  }
+  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
 
   std::string Type() const { return type_; }
   const AttributeMap& Attrs() const { return attrs_; }
@@ -154,11 +127,11 @@ class OperatorBase {
   // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
-  std::map<std::string, std::vector<std::string>> inputs_;
+  VarNameMap inputs_;
 
   // NOTE: in case of OpGrad, outputs_ contains
   // IG (Inputs Gradients)
-  std::map<std::string, std::vector<std::string>> outputs_;
+  VarNameMap outputs_;
   AttributeMap attrs_;
 };
@@ -177,11 +150,11 @@ class InferShapeContext {
       : op_(op), scope_(scope) {}
 
   size_t InputSize(const std::string& name) const {
-    return op_.inputs_.at(name).size();
+    return op_.Inputs(name).size();
   }
 
   size_t OutputSize(const std::string& name) const {
-    return op_.outputs_.at(name).size();
+    return op_.Outputs(name).size();
   }
 
   const Variable* InputVar(const std::string& name) const {

diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index 5fdb6bca02..46e419a8c8 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -56,19 +56,28 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 }  // namespace framework
 }  // namespace paddle
 
+static void ConstructVars(const std::string& param_name,
+                          std::initializer_list<const char*> arguments,
+                          paddle::framework::OpDesc::Var* var) {
+  var->set_parameter(param_name);
+  for (auto& arg_name : arguments) {
+    *var->mutable_arguments()->Add() = arg_name;
+  }
+}
+
 REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
             paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
 
 TEST(OperatorBase, all) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("test_operator");
+
   auto* ipt = op_desc.mutable_inputs()->Add();
-  *ipt->mutable_arguments()->Add() = "IN1";
-  ipt->set_parameter("input");
+  ConstructVars("input", {"IN1"}, ipt);
 
   auto* output = op_desc.mutable_outputs()->Add();
-  *output->mutable_arguments()->Add() = "OUT1";
-  output->set_parameter("output");
+  ConstructVars("output", {"OUT1"}, output);
+
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
   attr->set_type(paddle::framework::AttrType::FLOAT);
@@ -127,9 +136,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
   OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
                                               OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("xs", "inputs of test op").SetMultiple();
+    AddInput("xs", "inputs of test op").SetDuplicable();
     AddInput("k", "input of test op");
-    AddOutput("ys", "outputs of test op").SetMultiple();
+    AddOutput("ys", "outputs of test op").SetDuplicable();
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .LargerThan(0.0);
@@ -187,12 +196,10 @@ TEST(OpKernel, all) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("op_with_kernel");
   auto* ipt = op_desc.mutable_inputs()->Add();
-  *ipt->mutable_arguments()->Add() = "IN1";
-  ipt->set_parameter("x");
+  ConstructVars("x", {"IN1"}, ipt);
 
   auto* output = op_desc.mutable_outputs()->Add();
-  *output->mutable_arguments()->Add() = "OUT1";
-  output->set_parameter("y");
+  ConstructVars("y", {"OUT1"}, output);
 
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
@@ -219,18 +226,12 @@ TEST(OpKernel, multi_inputs) {
   OpDesc op_desc;
   op_desc.set_type("op_multi_inputs_with_kernel");
-  auto x = op_desc.mutable_inputs()->Add();
-  x->set_parameter("xs");
-  *x->mutable_arguments()->Add() = "x0";
-  *x->mutable_arguments()->Add() = "x1";
-  *x->mutable_arguments()->Add() = "x2";
-  auto k = op_desc.mutable_inputs()->Add();
-  k->set_parameter("k");
-  *k->mutable_arguments()->Add() = "k0";
-  auto y = op_desc.mutable_outputs()->Add();
-  y->set_parameter("ys");
-  *y->mutable_arguments()->Add() = "y0";
-  *y->mutable_arguments()->Add() = "y1";
+  auto* x = op_desc.mutable_inputs()->Add();
+  ConstructVars("xs", {"x0", "x1", "x2"}, x);
+  auto* k = op_desc.mutable_inputs()->Add();
+  ConstructVars("k", {"k0"}, k);
+  auto* y = op_desc.mutable_outputs()->Add();
+  ConstructVars("ys", {"y0", "y1"}, y);
 
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");

diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index 99e27a11a8..6e28c294b1 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
   MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of mean op");
-    AddOutput("Out", "The output of mean op").IgnoreGradient();
+    AddOutput("Out", "The output of mean op").NoGradient();
     AddComment("Mean Operator");
   }
 };

diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc
index 4ed338359e..ff02b69276 100644
--- a/paddle/operators/recurrent_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
     // inputs and outputs stored in proto
     AddInput(name.inlinks,
              "the inputs that need to be segmented for each step.")
-        .SetMultiple();
+        .SetDuplicable();
     AddInput(name.boot_memories, "variables to initialize memories.")
-        .SetMultiple();
+        .SetDuplicable();
     AddInput(name.step_net, "network shared by all steps.");
 
for all steps.") - .SetMultiple(); + .SetDuplicable(); AddOutput(name.step_scopes, "step scopes"); // Attributes stored in AttributeMap diff --git a/paddle/operators/recurrent_op_test.cc b/paddle/operators/recurrent_op_test.cc index 40c212d6b7..2f6eff0720 100644 --- a/paddle/operators/recurrent_op_test.cc +++ b/paddle/operators/recurrent_op_test.cc @@ -26,8 +26,6 @@ namespace paddle { namespace operators { using namespace paddle::framework; -// using framework::make_ddim; -// using framework::DDim; class RecurrentGradientAlgorithmTest : public ::testing::Test { protected: diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py index 019784a8b4..0def484edd 100644 --- a/python/paddle/v2/framework/tests/test_add_two_op.py +++ b/python/paddle/v2/framework/tests/test_add_two_op.py @@ -19,13 +19,5 @@ class TestAddOp(unittest.TestCase): self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} -#class TestAddGradOp(unittest.TestCase): -# def test_add_grad(self): -# op = Operator('add_two', X="X", Y="Y", Out="Out") -# backward_op = core.Operator.backward(op, set()) -# self.assertEqual(backward_op.type(), "add_two_grad") -# expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).''' -# self.assertEqual(expected, str(backward_op)) - if __name__ == '__main__': unittest.main() -- GitLab