提交 578a357b 编写于 作者: Y Yu Yang

Make compile pass

上级 e98aac51
......@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
......@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/operators/net_op.h"
#include <list>
#include <memory>
......@@ -24,6 +25,32 @@
namespace paddle {
namespace framework {
static inline std::unique_ptr<OperatorBase> CreateGradOp(
const OperatorBase& op) {
OpDescBind op_desc;
op_desc.SetInputMap(op.Inputs());
op_desc.SetOutputMap(op.Outputs());
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.grad_op_maker_(op_desc);
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(
grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops),
[](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); });
PADDLE_ENFORCE_GT(grad_ops.size(), 0);
if (grad_ops.size() == 1) {
return std::move(grad_ops[0]);
} else {
auto net_op = new operators::NetOp();
for (auto& grad_op : grad_ops) {
net_op->AppendOp(std::move(grad_op));
}
return std::unique_ptr<OperatorBase>(net_op);
}
}
template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) {
for (auto& name : names) {
......@@ -154,10 +181,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
net->InsertOp(pos.first + 1, std::move(pos.second));
}
} else {
OpDescBind fwd_desc;
fwd_desc.SetInput(forwardOp.Inputs());
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
PADDLE_ENFORCE(grad_op != nullptr);
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
......
......@@ -76,18 +76,22 @@ class OpDescBind {
return MapKeys(outputs_);
}
void SetInput(
const std::unordered_map<std::string, std::vector<std::string>> &input) {
void SetInputMap(const VariableNameMap &input) {
this->inputs_ = input;
this->need_update_ = true;
}
void SetOutput(
const std::unordered_map<std::string, std::vector<std::string>> &output) {
void SetOutputMap(const VariableNameMap &output) {
this->outputs_ = output;
this->need_update_ = true;
}
void Sync();
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......@@ -99,8 +103,6 @@ class OpDescBind {
return ret_val;
}
void Sync();
OpDesc op_desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
......
......@@ -52,5 +52,11 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
op_desc->Sync();
return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(),
op_desc->GetAttrMap());
}
} // namespace framework
} // namespace paddle
......@@ -23,8 +23,8 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
......@@ -46,15 +46,15 @@ class Registrar {
template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
std::cerr << "Reg operator " << op_type << std::endl;
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass");
details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
OpInfoMap::Instance().Insert(op_type, info);
}
~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
const char* op_type;
OpInfo info;
......@@ -79,6 +79,8 @@ class OpRegistry {
AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(OpDescBind* op_desc);
};
template <typename OpType, typename ProtoMakerType, typename GradOpType>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册