Commit 578a357b authored by Yu Yang

Make compile pass

Parent e98aac51
@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
-cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
 py_proto_compile(framework_py_proto SRCS framework.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
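Note: the grad_op_builder library and grad_op_builder_test are dropped from the framework build, and op_registry now depends directly on operator (plus op_proto_maker and op_info), matching the switch away from OpRegistry::CreateGradOp in backward.cc below.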
@@ -13,6 +13,7 @@
 limitations under the License. */
 #include "paddle/framework/backward.h"
+#include "paddle/operators/net_op.h"
 #include <list>
 #include <memory>
@@ -24,6 +25,32 @@
 namespace paddle {
 namespace framework {
+static inline std::unique_ptr<OperatorBase> CreateGradOp(
+    const OperatorBase& op) {
+  OpDescBind op_desc;
+  op_desc.SetInputMap(op.Inputs());
+  op_desc.SetOutputMap(op.Outputs());
+  op_desc.SetType(op.Type());
+  op_desc.SetAttrMap(op.Attrs());
+  auto& info = OpInfoMap::Instance().Get(op.Type());
+  auto grad_descs = info.grad_op_maker_(op_desc);
+  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
+  grad_ops.reserve(grad_descs.size());
+  std::transform(
+      grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops),
+      [](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); });
+  PADDLE_ENFORCE_GT(grad_ops.size(), 0);
+  if (grad_ops.size() == 1) {
+    return std::move(grad_ops[0]);
+  } else {
+    auto net_op = new operators::NetOp();
+    for (auto& grad_op : grad_ops) {
+      net_op->AppendOp(std::move(grad_op));
+    }
+    return std::unique_ptr<OperatorBase>(net_op);
+  }
+}
 template <typename Map, typename T>
 static void ForEachVarName(const Map& names, T callback) {
   for (auto& name : names) {
@@ -154,10 +181,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    OpDescBind fwd_desc;
-    fwd_desc.SetInput(forwardOp.Inputs());
-    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
+    PADDLE_ENFORCE(grad_op != nullptr);
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                        const std::string& grad_input) {
......
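Note: with grad_op_builder gone, the non-net branch of BackwardRecursive now calls the file-local CreateGradOp above instead of OpRegistry::CreateGradOp. The forward operator is copied into an OpDescBind, the op's registered grad_op_maker_ produces the gradient descriptions, and each one is instantiated through the new OpRegistry::CreateOp(OpDescBind*); when more than one gradient op comes back, they are bundled into an operators::NetOp.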
@@ -76,18 +76,22 @@ class OpDescBind {
     return MapKeys(outputs_);
   }
-  void SetInput(
-      const std::unordered_map<std::string, std::vector<std::string>> &input) {
+  void SetInputMap(const VariableNameMap &input) {
     this->inputs_ = input;
     this->need_update_ = true;
   }
-  void SetOutput(
-      const std::unordered_map<std::string, std::vector<std::string>> &output) {
+  void SetOutputMap(const VariableNameMap &output) {
     this->outputs_ = output;
     this->need_update_ = true;
   }
+  void Sync();
+  const VariableNameMap &Inputs() const { return inputs_; }
+  const VariableNameMap &Outputs() const { return outputs_; }
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
@@ -99,8 +103,6 @@ class OpDescBind {
     return ret_val;
   }
-  void Sync();
   OpDesc op_desc_;
   VariableNameMap inputs_;
   VariableNameMap outputs_;
......
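Note: the renamed setters line OpDescBind up with the VariableNameMap typedef already used by OperatorBase, and the new Inputs()/Outputs() accessors plus the now-public Sync() are exactly what CreateGradOp and OpRegistry::CreateOp(OpDescBind*) below rely on. A minimal sketch of the round-trip, using only members visible in this diff (`fwd` stands for an arbitrary OperatorBase and is assumed for illustration):

// Sketch only, not code from the commit: copy a forward operator's interface
// into an OpDescBind using the renamed map setters.
OpDescBind desc;
desc.SetType(fwd.Type());
desc.SetInputMap(fwd.Inputs());    // previously SetInput(unordered_map<...>)
desc.SetOutputMap(fwd.Outputs());  // previously SetOutput(unordered_map<...>)
desc.SetAttrMap(fwd.Attrs());
// The bound maps can be read back directly through the new accessors:
const VariableNameMap &ins = desc.Inputs();
const VariableNameMap &outs = desc.Outputs();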
@@ -52,5 +52,11 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
   return CreateOp(op_desc.type(), inputs, outputs, attrs);
 }
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
+  op_desc->Sync();
+  return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(),
+                  op_desc->GetAttrMap());
+}
 }  // namespace framework
 }  // namespace paddle
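Note: this new overload is what backward.cc uses to turn each gradient OpDescBind into a runtime operator. A minimal usage sketch, mirroring the lambda in CreateGradOp above (`grad_desc` is assumed to come from an op's grad_op_maker_):

// Sketch only: instantiate a gradient operator from a bound description.
std::unique_ptr<OperatorBase> grad_op = OpRegistry::CreateOp(&grad_desc);
// CreateOp(OpDescBind*) calls Sync() first so the bound maps are flushed to the
// underlying OpDesc proto, then delegates to CreateOp(type, inputs, outputs, attrs).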
@@ -23,8 +23,8 @@ limitations under the License. */
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/details/op_registry.h"
 #include "paddle/framework/framework.pb.h"
-#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/grad_op_desc_maker.h"
+#include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
@@ -46,15 +46,15 @@ class Registrar {
 template <typename... ARGS>
 struct OperatorRegistrar : public Registrar {
   explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
-    std::cerr << "Reg operator " << op_type << std::endl;
     PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
                    "'%s' is registered more than once.", op_type);
     static_assert(sizeof...(ARGS) != 0,
                   "OperatorRegistrar should be invoked at least by OpClass");
     details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
-    OpInfoMap::Instance().Insert(op_type, info);
   }
+  ~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
   const char* op_type;
   OpInfo info;
@@ -79,6 +79,8 @@ class OpRegistry {
                                                 AttributeMap attrs);
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
+  static std::unique_ptr<OperatorBase> CreateOp(OpDescBind* op_desc);
 };
 template <typename OpType, typename ProtoMakerType, typename GradOpType>
......
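Note: two changes to OperatorRegistrar here: the debug std::cerr print is removed, and the OpInfoMap insertion moves from the constructor body to the destructor, so the filled `info` is only published when the registrar object is destroyed. A sketch of the resulting lifetime (MyOp and MyOpMaker are placeholder names, not types from this diff):

// Sketch only: registrar lifetime implied by this change.
{
  OperatorRegistrar<MyOp, MyOpMaker> reg("my_op");
  // ctor: duplicate-registration check, then details::OperatorRegistrarRecursive
  // fills reg.info; nothing is inserted into OpInfoMap yet.
}  // dtor: OpInfoMap::Instance().Insert("my_op", reg.info) happens here now.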