提交 ff8766e9 编写于 作者: Y Yu Yang

Stash

上级 578a357b
...@@ -378,6 +378,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -378,6 +378,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+ 1UL /* external output number*/ + 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/ + 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/); + 2U /* internal variable number*/);
std::cerr << grad_fc.DebugString() << std::endl;
EXPECT_EQ(grad_fc.Outputs(all).size(), EXPECT_EQ(grad_fc.Outputs(all).size(),
2UL /* input number of mul*/ 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add + 2UL /* input number of rowwise_add
......
...@@ -85,6 +85,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -85,6 +85,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new OpProto; info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_); auto maker = T(info->proto_, info->checker_);
std::cerr << "Assign Maker " << op_type << std::endl;
maker.Validate(); maker.Validate();
info->proto_->set_type(op_type); info->proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -98,7 +98,7 @@ class OpDescBind { ...@@ -98,7 +98,7 @@ class OpDescBind {
std::vector<typename MapType::key_type> ret_val; std::vector<typename MapType::key_type> ret_val;
ret_val.reserve(map.size()); ret_val.reserve(map.size());
std::transform( std::transform(
map.begin(), map.end(), ret_val.begin(), map.begin(), map.end(), std::back_inserter(ret_val),
[](const typename MapType::value_type &pair) { return pair.first; }); [](const typename MapType::value_type &pair) { return pair.first; });
return ret_val; return ret_val;
} }
......
...@@ -42,19 +42,11 @@ struct OpInfo { ...@@ -42,19 +42,11 @@ struct OpInfo {
return *proto_; return *proto_;
} }
const OpAttrChecker& Checker() const {
PADDLE_ENFORCE_NOT_NULL(checker_,
"Operator Checker has not been registered");
return *checker_;
}
const OpCreator& Creator() const { const OpCreator& Creator() const {
PADDLE_ENFORCE_NOT_NULL(creator_, PADDLE_ENFORCE_NOT_NULL(creator_,
"Operator Creator has not been registered"); "Operator Creator has not been registered");
return creator_; return creator_;
} }
bool HasGradientOp() const { return !grad_op_type_.empty(); }
}; };
class OpInfoMap { class OpInfoMap {
......
...@@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker { ...@@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker {
var_->set_intermediate(true); var_->set_intermediate(true);
return *this; return *this;
} }
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
}; };
VariableBuilder AddInput(const std::string& name, const std::string& comment); VariableBuilder AddInput(const std::string& name, const std::string& comment);
......
...@@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp( ...@@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs, const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs) { const VariableNameMap& outputs, AttributeMap attrs) {
auto& info = OpInfoMap::Instance().Get(type); auto& info = OpInfoMap::Instance().Get(type);
info.Checker().Check(attrs); if (info.checker_ != nullptr) {
info.checker_->Check(attrs);
}
auto op = info.Creator()(type, inputs, outputs, attrs); auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op); return std::unique_ptr<OperatorBase>(op);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册