提交 78af6e60 编写于 作者: Y Yu Yang

Add OutputVars method to get all outputs or outputs without intermediate

上级 b368c6ca
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
...@@ -127,7 +128,7 @@ class OpRegistry { ...@@ -127,7 +128,7 @@ class OpRegistry {
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; }; op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type]; OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker); auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate(); maker.Validate();
*op_proto.mutable_type() = op_type; *op_proto.mutable_type() = op_type;
...@@ -135,17 +136,6 @@ class OpRegistry { ...@@ -135,17 +136,6 @@ class OpRegistry {
op_proto.IsInitialized(), op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized", "Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString()); op_type, op_proto.InitializationErrorString());
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
}
} }
template <typename GradOpType> template <typename GradOpType>
...@@ -212,22 +202,11 @@ class OpRegistry { ...@@ -212,22 +202,11 @@ class OpRegistry {
return grad_op; return grad_op;
} }
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
}
static std::unordered_map<std::string, std::string>& grad_ops() { static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_; static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_; return grad_ops_;
} }
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() { static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_; static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_; return op_creators_;
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,6 +33,14 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
static std::unordered_map<std::string, OpProto>* g_op_protos = nullptr;
std::unordered_map<std::string, OpProto>& OpProtos() {
if (g_op_protos == nullptr) {
g_op_protos = new std::unordered_map<std::string, OpProto>();
}
return *g_op_protos;
}
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name); auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_, PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_,
......
...@@ -50,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -50,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
extern std::unordered_map<std::string, OpProto>& OpProtos();
class OperatorBase; class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
...@@ -103,6 +105,35 @@ class OperatorBase { ...@@ -103,6 +105,35 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
const std::vector<std::string>& Outputs(const std::string& name) const; const std::vector<std::string>& Outputs(const std::string& name) const;
virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
public: public:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
......
...@@ -21,19 +21,20 @@ ...@@ -21,19 +21,20 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
const char NetOp::kAll[] = "all";
void NetOp::CompleteAddOp(bool calc) { void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::set<std::string> input_set; std::set<std::string> input_set;
std::set<std::string> output_set; std::set<std::string> output_set;
std::set<std::string> temp_output;
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->inputs_) { for (auto& ipt : op->inputs_) {
for (auto& var_name : ipt.second) { for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name); input_set.insert(var_name);
} else { } else {
temp_output.insert(var_name); intermediate_outputs_.insert(var_name);
} }
} }
} }
...@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -44,24 +45,12 @@ void NetOp::CompleteAddOp(bool calc) {
} }
} }
} }
auto& inputs = inputs_["all"]; auto& inputs = inputs_[kAll];
inputs.reserve(input_set.size()); inputs.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs)); std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs));
auto& outputs = outputs_["all"]; auto& outputs = outputs_[kAll];
outputs.reserve(output_set.size()); outputs.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs)); std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
//! TODO figure out how to generate temporary_index in Network.
std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size());
int output_len = static_cast<int>(outputs.size());
for (int i = 0; i < output_len; ++i) {
if (Contains(temp_output, outputs[i])) {
tmp_index.push_back(i);
}
}
attrs_["temporary_index"] = tmp_index;
} }
std::string NetOp::DebugString() const { std::string NetOp::DebugString() const {
...@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const { ...@@ -78,5 +67,19 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; } bool NetOp::IsNetOp() const { return true; }
std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
if (has_intermediate) {
return this->outputs_.at(kAll);
}
auto& all = this->outputs_.at(kAll);
std::vector<std::string> ret_val;
for (auto& each : all) {
if (!Contains(intermediate_outputs_, each)) {
ret_val.push_back(each);
}
}
return ret_val;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -36,6 +36,8 @@ namespace operators { ...@@ -36,6 +36,8 @@ namespace operators {
*/ */
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
static const char kAll[];
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
...@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase { ...@@ -91,11 +93,13 @@ class NetOp : public framework::OperatorBase {
std::string DebugString() const override; std::string DebugString() const override;
bool IsNetOp() const override; bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
bool add_op_done_{false}; bool add_op_done_{false};
std::set<std::string> intermediate_outputs_;
template <typename T, typename KeyType> template <typename T, typename KeyType>
static bool Contains(T container, KeyType key) { static bool Contains(T container, KeyType key) {
......
...@@ -54,22 +54,13 @@ TEST(OpKernel, all) { ...@@ -54,22 +54,13 @@ TEST(OpKernel, all) {
net->CompleteAddOp(); net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
net->inputs_.at("__all__")); net->inputs_.at(NetOp::kAll));
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at("__all__")); AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll));
auto tmp_idx_iter = net->attrs_.find("temporary_index");
ASSERT_NE(net->attrs_.end(), tmp_idx_iter);
auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_.at("__all__")[tmp_idx[0]]);
Scope scope; auto final_outs = net->OutputVars(false);
platform::CPUDeviceContext dev_ctx;
net->InferShape(scope); ASSERT_EQ(final_outs.size(), 1UL);
net->Run(scope, dev_ctx); ASSERT_EQ(final_outs[0], "z");
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
} }
TEST(NetOp, insert_op) { TEST(NetOp, insert_op) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册