Commit 1ed5f02d authored by Yu Yang, committed by GitHub

Merge pull request #14 from reyoung/feature/refactorize_framework_proto

Polish our code based on YuYang's review
......@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient();
AddOutput("Out", "Out of Add").IgnoreGradient();
AddInput("X", "Input X of Add").AsNoGradient();
AddInput("b", "Bias of Add").AsNoGradient();
AddOutput("Out", "Out of Add").AsNoGradient();
AddComment("Add Op");
}
};
......@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_result", "").SetTemporary();
AddOutput("add_result", "").SetTemporary();
AddOutput("mul_result", "").AsIntermediate();
AddOutput("add_result", "").AsIntermediate();
AddOutput("Out", "");
AddComment("");
}
......@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple();
AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y");
AddComment("");
}
......@@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
auto bwd_net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_["all"].size(),
const char *all = paddle::operators::NetOp::kAll;
EXPECT_EQ(grad_fc.inputs_[all].size(),
2UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_["all"].size(),
EXPECT_EQ(grad_fc.outputs_[all].size(),
2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */
+ 1UL /* input number of sigmoid */);
EXPECT_EQ(bwd_net->ops_[1]->inputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL);
}
......@@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list);
}
std::string DDim::DebugString() const {
std::ostringstream ss;
ss << *this;
return ss.str();
}
} // namespace framework
} // namespace paddle
......@@ -72,8 +72,6 @@ struct DDim {
DDim operator*(DDim d) const;
ssize_t size() const;
std::string DebugString() const;
};
/**
......
......@@ -18,9 +18,6 @@ permissions and limitations under the License. */
namespace paddle {
namespace framework {
class OpRegistry;
enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
......
......@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple();
AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input");
AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").SetMultiple();
AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs");
}
};
......@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient();
AddInput("In3_mult", "another multiple input").SetMultiple();
AddOutput("Out1_mult", "a multiple output").SetMultiple();
AddOutput("Out2", "a single output").IgnoreGradient();
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient();
AddComment("op with inputs and outputs ignored in gradient calculating");
}
};
......
......@@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker {
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& SetMultiple() {
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& SetTemporary() {
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& IgnoreGradient() {
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means the input/output is not needed when calculating the gradient; it does
// not mean the variable has no gradient in the backward pass. It should be renamed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this;
}
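For reference, a minimal sketch of how the renamed builder methods are chained when describing an operator. `MyOpMaker` is hypothetical and not part of this commit; it assumes the framework headers shown elsewhere in this diff.
class MyOpMaker : public OpProtoAndCheckerMaker {
 public:
  MyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // AsDuplicable: this slot may hold several variables.
    AddInput("X", "a duplicable input").AsDuplicable();
    // AsNoGradient: this variable is not needed when computing the gradient.
    AddInput("b", "a bias input").AsNoGradient();
    // AsIntermediate: an internal output, skipped by OutputVars(false).
    AddOutput("Tmp", "an intermediate result").AsIntermediate();
    AddOutput("Out", "the real output");
    AddComment("illustrative op");
  }
};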
......@@ -118,7 +121,7 @@ class OpProtoAndCheckerMaker {
class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
using VarNameMap = OperatorBase::VarNameMap;
public:
template <typename OpType, typename ProtoMakerType>
......@@ -164,25 +167,22 @@ class OpRegistry {
return std::shared_ptr<OperatorBase>(op);
}
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs;
for (auto& input : op_desc.inputs()) {
auto& var_names = inputs[input.parameter()];
auto& var_names_in_proto = input.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
VarNameMap outputs;
for (auto& output : op_desc.outputs()) {
auto& var_names = outputs[output.parameter()];
auto& var_names_in_proto = output.arguments();
static VarNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
VarNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
auto& var_names_in_proto = var.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
return ret_val;
}
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
......
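A minimal usage sketch of the refactored creation path (illustrative only; it assumes the framework headers and a registered "cos_sim" operator, as in the tests below): build an OpDesc, then let CreateOp convert its inputs and outputs through ConvertOpDescVarsToVarNameMap.
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto* in = op_desc.add_inputs();
in->set_parameter("input");
in->add_arguments("aa");
auto* out = op_desc.add_outputs();
out->set_parameter("output");
out->add_arguments("bb");
// Both input and output maps are now built by the shared helper.
std::shared_ptr<paddle::framework::OperatorBase> op =
    paddle::framework::OpRegistry::CreateOp(op_desc);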
......@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").SetMultiple();
AddOutput("output", "output of cosine op").SetTemporary();
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
......@@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
......@@ -59,13 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
......@@ -85,13 +89,8 @@ TEST(OpRegistry, CreateOp) {
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -115,13 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
ASSERT_TRUE(op_desc.IsInitialized());
......@@ -136,13 +130,8 @@ TEST(OpRegistry, DefaultValue) {
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "ii";
auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "oo";
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());
// attr 'test_attr' is not set
bool caught = false;
......
......@@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
}
const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
name);
PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
name);
return it->second[0];
return ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
return inputs_.at(name);
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
name);
return it->second;
}
const std::string& OperatorBase::Output(const std::string& name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
name);
PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
auto& outs = Outputs(name);
PADDLE_ENFORCE_EQ(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_,
name);
return it->second[0];
return outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
return outputs_.at(name);
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
name);
return it->second;
}
std::string OperatorBase::DebugString() const {
......@@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
} // namespace framework
} // namespace paddle
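A brief usage sketch (illustrative only, not part of this commit): for an operator whose proto marks some outputs with AsIntermediate(), such as the FcOp used in the backward tests above, the two flavors of OutputVars behave as follows, assuming `fc_op` was created through OpRegistry::CreateOp.
// Returns every output name, including intermediates such as
// "mul_result" and "add_result".
std::vector<std::string> all_outputs = fc_op->OutputVars(true);
// Consults the registered OpProto and skips outputs marked intermediate,
// so only the real outputs (e.g. "Out") remain.
std::vector<std::string> visible_outputs = fc_op->OutputVars(false);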
......@@ -116,34 +116,7 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const;
virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
std::string Type() const { return type_; }
const AttributeMap& Attrs() const { return attrs_; }
......@@ -154,11 +127,11 @@ class OperatorBase {
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::map<std::string, std::vector<std::string>> inputs_;
VarNameMap inputs_;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::map<std::string, std::vector<std::string>> outputs_;
VarNameMap outputs_;
AttributeMap attrs_;
};
......@@ -177,11 +150,11 @@ class InferShapeContext {
: op_(op), scope_(scope) {}
size_t InputSize(const std::string& name) const {
return op_.inputs_.at(name).size();
return op_.Inputs(name).size();
}
size_t OutputSize(const std::string& name) const {
return op_.outputs_.at(name).size();
return op_.Outputs(name).size();
}
const Variable* InputVar(const std::string& name) const {
......
......@@ -56,19 +56,24 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("input");
BuildVar("IN1", {"input"}, op_desc.add_inputs());
BuildVar("OUT1", {"output"}, op_desc.add_outputs());
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("output");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
......@@ -127,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").SetMultiple();
AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").SetMultiple();
AddOutput("ys", "outputs of test op").AsDuplicable();
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
......@@ -186,13 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("x");
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("y");
BuildVar("IN1", {"x"}, op_desc.add_inputs());
BuildVar("OUT1", {"y"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -219,18 +219,9 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
auto x = op_desc.mutable_inputs()->Add();
x->set_parameter("xs");
*x->mutable_arguments()->Add() = "x0";
*x->mutable_arguments()->Add() = "x1";
*x->mutable_arguments()->Add() = "x2";
auto k = op_desc.mutable_inputs()->Add();
k->set_parameter("k");
*k->mutable_arguments()->Add() = "k0";
auto y = op_desc.mutable_outputs()->Add();
y->set_parameter("ys");
*y->mutable_arguments()->Add() = "y0";
*y->mutable_arguments()->Add() = "y1";
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs());
BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......
......@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").IgnoreGradient();
AddOutput("Out", "The output of mean op").AsNoGradient();
AddComment("Mean Operator");
}
};
......
......@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.SetMultiple();
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.SetMultiple();
.AsDuplicable();
AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
.SetMultiple();
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
......
......@@ -26,8 +26,6 @@ namespace paddle {
namespace operators {
using namespace paddle::framework;
// using framework::make_ddim;
// using framework::DDim;
class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected:
......
......@@ -19,13 +19,5 @@ class TestAddOp(unittest.TestCase):
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
#class TestAddGradOp(unittest.TestCase):
# def test_add_grad(self):
# op = Operator('add_two', X="X", Y="Y", Out="Out")
# backward_op = core.Operator.backward(op, set())
# self.assertEqual(backward_op.type(), "add_two_grad")
# expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
# self.assertEqual(expected, str(backward_op))
if __name__ == '__main__':
unittest.main()