Commit 32022d4d authored by willzhang4a58

blob desc

Parent 0e9c3917
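This commit swaps the `Shape*`-based inference API for a `BlobDesc*`-based one across the task graph, operators, and register descriptors. The pivotal signature change, taken from the `Operator` hunk further down:

```cpp
// Before: ops could only infer the Shape of each named blob.
virtual void InferShape4FwBlobs(
    std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
    ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const = 0;

// After: ops infer a full BlobDesc per blob (so far a BlobDesc only wraps a
// Shape, but the indirection presumably leaves room for data type, etc.).
virtual void InferBlobDesc4FwBlobs(
    std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
    ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const = 0;
```

Correspondingly, `RegstDesc` now maps each lbn to a `BlobDesc` instead of a `Shape`, `CopyShapeFrom` becomes `CopyBlobDescFrom`, and the `lbn2shape` proto field becomes `lbn2blob_desc`.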
......@@ -108,6 +108,8 @@ foreach(cc ${of_main_cc})
endforeach()
# build test
cuda_add_executable(oneflow_testexe ${of_all_test_cc})
target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs})
add_test(NAME oneflow_test COMMAND oneflow_testexe)
if(BUILD_TESTING)
cuda_add_executable(oneflow_testexe ${of_all_test_cc})
target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs})
add_test(NAME oneflow_test COMMAND oneflow_testexe)
endif()
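The test targets are now guarded by `BUILD_TESTING`, the standard switch defined by CMake's CTest module (ON by default once `include(CTest)` runs; the exact top-level setup of this repo is not shown in this hunk). Typical usage:

```
cmake -DBUILD_TESTING=OFF ..                   # configure without building oneflow_testexe
cmake -DBUILD_TESTING=ON .. && make && ctest   # build and run oneflow_test
```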
......@@ -173,11 +173,11 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
}
}
void BoxingTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
void BoxingTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) {
exec_gph().ConstForEachNode([this](const ExecNode* exec_node) {
exec_node->op()->InferShape4FwBlobs(exec_node->GetMutShapePtr4BnInOpFunc(),
chain_node()->parallel_desc()->policy(),
0, 0);
exec_node->op()->InferBlobDesc4FwBlobs(
exec_node->GetBlobDesc4BnInOpFunc(),
chain_node()->parallel_desc()->policy(), 0, 0);
});
}
......@@ -232,16 +232,16 @@ void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph().UpdateSourceAndSink();
}
void BoxingTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
void BoxingTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) {
for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) {
auto in_regst = GetRelatedRegst(fw_in_edge);
if (auto in_diff_regst = GetBpRegstFromFwRegst(in_regst)) {
in_diff_regst->CopyShapeFrom(in_regst.get());
in_diff_regst->CopyBlobDescFrom(in_regst.get());
}
}
auto fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
auto bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->CopyShapeFrom(fw_middle_regst.get());
bp_middle_regst->CopyBlobDescFrom(fw_middle_regst.get());
}
} // namespace oneflow
......@@ -41,12 +41,12 @@ class BoxingTaskNode : public TaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferShapeOfBlobsInProducedRegsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts);
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void FwInferShapeOfBlobsInProducedRegsts(TaskGraph*);
void FwInferBlobDescInProducedRegsts(TaskGraph*);
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void BpInferShapeOfBlobsInProducedRegsts(TaskGraph*);
void BpInferBlobDescInProducedRegsts(TaskGraph*);
void EnrollAllRegstAndBindRelatedEdge();
TaskType task_type() const override { return kBoxingTask; }
......
......@@ -25,10 +25,10 @@ void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph().UpdateSourceAndSink();
}
void CopyTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph*) {
void CopyTaskNode::InferBlobDescInProducedRegsts(TaskGraph*) {
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
std::shared_ptr<RegstDesc> out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyShapeFrom(in_regst.get());
out_regst->CopyBlobDescFrom(in_regst.get());
}
void CopyHDTaskNode::SetFwInCopy() {
......
......@@ -16,7 +16,7 @@ class CopyTaskNode : public TaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph*) override;
void InferBlobDescInProducedRegsts(TaskGraph*) override;
};
class CopyHDTaskNode final : public CopyTaskNode {
......
......@@ -24,17 +24,16 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
FwEnrollLbn2ModelAndTmpRegsts(); // model model_tmp data_tmp
}
void DataCompTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
void DataCompTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) {
exec_gph().ConstTopoForEachNode([this](const ExecNode* node) {
node->op()->InferShape4FwBlobs(
node->GetMutShapePtr4BnInOpFunc(),
chain_node()->parallel_desc()->policy(), parallel_id(),
chain_node()->parallel_desc()->parallel_num());
node->op()->InferBlobDesc4FwBlobs(
node->GetBlobDesc4BnInOpFunc(), chain_node()->parallel_desc()->policy(),
parallel_id(), chain_node()->parallel_desc()->parallel_num());
});
if (IsLossNode()) {
auto out_regst = GetRelatedRegst(SoleOutEdge());
auto in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyShapeFrom(in_regst.get());
out_regst->CopyBlobDescFrom(in_regst.get());
}
}
......@@ -173,20 +172,20 @@ void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
BpEnrollLbn2ProducedRegst();
}
void DataCompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
void DataCompTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) {
// in_diff_regst
auto in_diff_regst = GetProducedRegstDesc("in_diff");
auto in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
in_diff_regst->CopyShapeFrom(in_regst.get());
in_diff_regst->CopyBlobDescFrom(in_regst.get());
// model_diff_regst
if (auto md_diff_regst = GetProducedRegstDesc("model_diff")) {
md_diff_regst->CopyShapeFrom(
md_diff_regst->CopyBlobDescFrom(
GetFwNode()->GetConsumedRegstDesc("model").get());
}
// activation_diff_regst
if (auto acti_diff_regst = GetProducedRegstDesc("activation_diff")) {
auto acti_regst = GetFwNode()->GetProducedRegstDesc("activation");
acti_diff_regst->CopyShapeFrom(acti_regst.get());
acti_diff_regst->CopyBlobDescFrom(acti_regst.get());
}
}
......
......@@ -41,12 +41,12 @@ class DataCompTaskNode final : public CompTaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferShapeOfBlobsInProducedRegsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts);
using Lbn2NodeBnMap = HashMap<std::string, std::pair<ExecNode*, std::string>>;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph);
void FwInferShapeOfBlobsInProducedRegsts(TaskGraph* gph);
void FwInferBlobDescInProducedRegsts(TaskGraph* gph);
void FwBuildFromUserOps(Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetExecNodeFromInRegst(const Lbn2NodeBnMap& extern_in_lbn2consumer);
......@@ -56,7 +56,7 @@ class DataCompTaskNode final : public CompTaskNode {
void FwEnrollLbn2ActivationRegst();
void FwEnrollLbn2ModelAndTmpRegsts();
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void BpInferShapeOfBlobsInProducedRegsts(TaskGraph*);
void BpInferBlobDescInProducedRegsts(TaskGraph*);
void BpBuildExecGraph();
void BpEnrollLbn2ProducedRegst();
void BpEnrollLbn2ActivationDiffRegst();
......
......@@ -4,14 +4,14 @@ namespace oneflow {
void ExecEdge::set_lbn(const std::string& lbn) { lbn_ = lbn; }
std::function<Shape*(const std::string&)> ExecNode::GetMutShapePtr4BnInOpFunc()
std::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc()
const {
return [this](const std::string& bn_in_op) -> Shape* {
return [this](const std::string& bn_in_op) -> BlobDesc* {
auto it = this->bn_in_op2regst_.find(bn_in_op);
if (it == this->bn_in_op2regst_.end()) { return nullptr; }
std::shared_ptr<RegstDesc> regst = it->second.lock();
const std::string& lbn = this->op()->Lbn4BnInOp(bn_in_op);
return regst->GetMutShapePtr(lbn);
return regst->GetMutBlobDesc(lbn);
};
}
......
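The closure returned by `GetBlobDesc4BnInOpFunc` resolves an op-local blob name (`bn_in_op`) to its regst, translates it to the logical blob name via `Lbn4BnInOp`, and hands back the mutable `BlobDesc*` stored there, or `nullptr` when the name is bound to no regst. A sketch of how an op body consumes it during inference, assuming an elementwise op with blob names "in"/"out":

```cpp
auto GetBlobDesc4BnInOp = exec_node->GetBlobDesc4BnInOpFunc();
const BlobDesc* in = GetBlobDesc4BnInOp("in");  // bn_in_op -> regst -> lbn -> BlobDesc*
BlobDesc* out = GetBlobDesc4BnInOp("out");
if (in && out) { *out = *in; }                  // e.g. what ReluOp now does
```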
......@@ -55,7 +55,7 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
return bn_in_op2regst_;
}
std::function<Shape*(const std::string&)> GetMutShapePtr4BnInOpFunc() const;
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOpFunc() const;
std::string VisualStr() const { return op_->op_name(); }
......
......@@ -22,11 +22,11 @@ void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
}
void LossAccCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
void LossAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
if (!chain_node()->op_vec().empty()) {
auto loss_regst = GetConsumedRegstDesc("loss");
auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
loss_acc_regst->CopyShapeFrom(loss_regst.get());
loss_acc_regst->CopyBlobDescFrom(loss_regst.get());
}
}
......
......@@ -13,7 +13,7 @@ class LossAccCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kLossAccCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<LossAccCompTaskNode>();
......
......@@ -20,7 +20,6 @@ void LossRecordCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
}
void LossRecordCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
}
void LossRecordCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {}
} // namespace oneflow
......@@ -13,7 +13,7 @@ class LossRecordCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
bool IsMeaningLess() const override {
return !GetConsumedRegstDesc("loss_acc");
}
......
......@@ -35,13 +35,13 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
}
void MdDiffAccCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
void MdDiffAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
if (!chain_node()->op_vec().empty()) {
std::shared_ptr<RegstDesc> in_regst = GetConsumedRegstDesc("model_diff");
std::shared_ptr<RegstDesc> out_regst =
GetProducedRegstDesc("model_diff_acc");
out_regst->CopyShapeFrom(in_regst.get());
out_regst->CopyBlobDescFrom(in_regst.get());
}
}
......
......@@ -19,7 +19,7 @@ class MdDiffAccCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kMdDiffAccCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdDiffAccCompTaskNode>();
......
......@@ -32,7 +32,7 @@ void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
}
}
void MdSaveCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
void MdSaveCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
}
......
......@@ -22,7 +22,7 @@ class MdSaveCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
bool IsMeaningLess() const override { return !GetConsumedRegstDesc("model"); }
TaskType task_type() const override { return kMdSaveCompTask; }
......
......@@ -34,17 +34,17 @@ void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
}
void MdUpdtCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
void MdUpdtCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
ExecNode* exec_node = exec_gph().SoleNode();
auto model_diffs_regst = GetConsumedRegstDesc("model_diffs");
Shape packed_model_diffs_shape({model_diffs_regst->CompElemCntOfAllBlob()});
exec_node->op()->InferShape4FwBlobs(
[&](const std::string& bn_in_op) -> Shape* {
BlobDesc packed_blob_desc = model_diffs_regst->CompPackedBlobDesc();
exec_node->op()->InferBlobDesc4FwBlobs(
[&](const std::string& bn_in_op) -> BlobDesc* {
if (bn_in_op == "model_diffs") {
return &packed_model_diffs_shape;
return &packed_blob_desc;
} else {
return exec_node->GetMutShapePtr4BnInOpFunc()(bn_in_op);
return exec_node->GetBlobDesc4BnInOpFunc()(bn_in_op);
}
},
kDataParallel, 0, 0);
......
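`MdUpdtCompTaskNode` is the one caller that cannot use the stock callback for every name: for "model_diffs" it substitutes a single packed descriptor spanning every blob in the consumed regst, built by the new `RegstDesc::CompPackedBlobDesc()`. Note that `packed_blob_desc` is a stack local captured by reference, so the lambda is only valid for the duration of this `InferBlobDesc4FwBlobs` call. A hypothetical illustration of the packing:

```cpp
// Suppose the model_diffs regst holds two diff blobs (names made up here):
//   "weight_diff" -> Shape({100})   and   "bias_diff" -> Shape({28})
BlobDesc packed = model_diffs_regst->CompPackedBlobDesc();
// CompElemCntOfAllBlob() == 128, so packed.shape() == Shape({128})
```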
......@@ -35,7 +35,7 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kMdUpdtCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdUpdtCompTaskNode>();
......
......@@ -33,10 +33,9 @@ void TaskGraph::BuildExecAndEnrollLbn2Regsts() {
[this](TaskNode* node) { node->BuildExecAndEnrollLbn2Regsts(this); });
}
void TaskGraph::InferShapeOfBlobsInProducedRegsts() {
TopoForEachNode([this](TaskNode* node) {
node->InferShapeOfBlobsInProducedRegsts(this);
});
void TaskGraph::InferBlobDescInProducedRegsts() {
TopoForEachNode(
[this](TaskNode* node) { node->InferBlobDescInProducedRegsts(this); });
}
std::vector<CompTaskNode*> TaskGraph::CompTasksInChain(const ChainNode* chain) {
......
......@@ -22,7 +22,7 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); }
std::vector<CompTaskNode*> CompTasksInChain(const ChainNode*);
void InferShapeOfBlobsInProducedRegsts();
void InferBlobDescInProducedRegsts();
const std::string& name() const { return name_; }
......
......@@ -42,7 +42,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
//
virtual void BuildExecAndEnrollLbn2Regsts(TaskGraph*) = 0;
virtual void InferShapeOfBlobsInProducedRegsts(TaskGraph*) = 0;
virtual void InferBlobDescInProducedRegsts(TaskGraph*) = 0;
#define OVERRIDE_IF_FW_BP_FOR_FUNC(func_name) \
void func_name(TaskGraph* gph) override { \
......
......@@ -168,7 +168,7 @@ void Compiler::BuildLossGraph(
void Compiler::InferShape4Regsts() {
for (auto& task_gph : ordered_task_gphs_) {
LOG(INFO) << "InferShape for " << task_gph->name();
task_gph->InferShapeOfBlobsInProducedRegsts();
task_gph->InferBlobDescInProducedRegsts();
}
}
......
......@@ -31,13 +31,13 @@ std::string BoxingOp::obn2lbn(const std::string& output_bn) const {
return GetStringFromSpecialConf("lbn");
}
void BoxingOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void BoxingOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
auto boxing_conf = op_conf().boxing_conf();
auto in_box_case = boxing_conf.in_box_case();
std::vector<int64_t> data_tmp_blob_shape_vec =
GetShapePtr4BnInOp(input_bns().at(0))->dim_vec();
GetBlobDesc4BnInOp(input_bns().at(0))->shape().dim_vec();
// If it is a concat-box, accumulate the dimensions along the concat axis;
// otherwise only check that all boxes have the same shape.
......@@ -47,7 +47,8 @@ void BoxingOp::InferShape4FwBlobs(
CHECK(concat_axis == 0 || concat_axis == 1);
}
for (size_t ib_idx = 1; ib_idx < input_bns().size(); ++ib_idx) {
auto ib_shape_vec = GetShapePtr4BnInOp(input_bns().at(ib_idx))->dim_vec();
auto ib_shape_vec =
GetBlobDesc4BnInOp(input_bns().at(ib_idx))->shape().dim_vec();
for (size_t i = 0; i < ib_shape_vec.size(); ++i) {
if (in_box_case == BoxingOpConf::kConcatBox && i == concat_axis) {
data_tmp_blob_shape_vec[i] += ib_shape_vec[i];
......@@ -61,7 +62,8 @@ void BoxingOp::InferShape4FwBlobs(
// it is stored back if and only if this is a concat-clone box
if (in_box_case == BoxingOpConf::kConcatBox
&& out_box_case == BoxingOpConf::kCloneBox) {
*GetShapePtr4BnInOp(SoleDtbn()) = Shape(data_tmp_blob_shape_vec);
GetBlobDesc4BnInOp(SoleDtbn())->mut_shape() =
Shape(data_tmp_blob_shape_vec);
}
CHECK_NE(out_box_case, BoxingOpConf::OUT_BOX_NOT_SET);
if (out_box_case == BoxingOpConf::kDataSplitBox) {
......@@ -70,11 +72,12 @@ void BoxingOp::InferShape4FwBlobs(
auto output_shape_vec = data_tmp_blob_shape_vec;
for (size_t i = 0; i < out_num; ++i) {
output_shape_vec[0] = splitter.At(i).size();
*GetShapePtr4BnInOp(output_bns()[i]) = Shape(output_shape_vec);
GetBlobDesc4BnInOp(output_bns()[i])->mut_shape() =
Shape(output_shape_vec);
}
} else if (out_box_case == BoxingOpConf::kCloneBox) {
for (auto obn : output_bns()) {
*GetShapePtr4BnInOp(obn) = Shape(data_tmp_blob_shape_vec);
GetBlobDesc4BnInOp(obn)->mut_shape() = Shape(data_tmp_blob_shape_vec);
}
} else {
UNEXPECTED_RUN();
......
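For a concat-box, the input shapes accumulate along the concat axis; the `box_4_10x5x6x6` test below exercises exactly this. Under model parallelism a 17-channel blob split over 4 boxes arrives as slices, and the data-tmp blob recovers the full shape:

```cpp
// Four inputs {10,4,6,6}, {10,4,6,6}, {10,4,6,6}, {10,5,6,6}, concat axis 1:
//   data_tmp_blob_shape_vec = {10, 4+4+4+5, 6, 6} == {10, 17, 6, 6}
// For every other axis the loop CHECKs that all inputs agree.
```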
......@@ -14,8 +14,8 @@ class BoxingOp final : public SysOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -36,7 +36,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
};
// do infer shape
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1);
// test results
// output_shape should be:
......@@ -58,7 +58,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_op = ConstructOp(op_conf);
// do infer shape
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1);
// test results
// output shape should be the same as input
......@@ -75,7 +75,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_op = ConstructOp(op_conf);
// do infer shape
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1);
// data_tmp_shape is {10, 17, 6, 6}, and the 17 = 4 + 4 + 4 + 5
Shape* data_tmp_shape_ptr = bn2shape_ptr.at(boxing_op->SoleDtbn());
......
......@@ -16,12 +16,12 @@ const PbMessage& CloneOp::GetSpecialConf() const {
return op_conf().clone_conf();
}
void CloneOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void CloneOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn());
const BlobDesc* input_blob_desc = GetBlobDesc4BnInOp(SoleIbn());
for (std::string obn : output_bns()) {
*GetShapePtr4BnInOp(obn) = *input_shape_ptr;
*GetBlobDesc4BnInOp(obn) = *input_blob_desc;
}
}
......
......@@ -15,8 +15,8 @@ class CloneOp final : public SysOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -18,7 +18,7 @@ TEST(CloneOp, clone_4x3_3_times) {
return bn2shape_ptr.at(bn);
};
clone_op->InferShape4FwBlobs(fp, kDataParallel, 3, 10);
clone_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 3, 10);
Shape* input_shape_ptr = bn2shape_ptr.at(clone_op->SoleIbn());
for (std::string obn : clone_op->output_bns()) {
......
......@@ -18,23 +18,26 @@ const PbMessage& ConcatOp::GetSpecialConf() const {
return op_conf().concat_conf();
}
void ConcatOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void ConcatOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
std::vector<int64_t> vec = GetShapePtr4BnInOp(input_bns().at(0))->dim_vec();
std::vector<int64_t> vec =
GetBlobDesc4BnInOp(input_bns().at(0))->shape().dim_vec();
for (size_t ibn_idx = 1; ibn_idx < input_bns().size(); ++ibn_idx) {
Shape* ib_shape = GetShapePtr4BnInOp(input_bns().at(ibn_idx));
const Shape& ib_shape =
GetBlobDesc4BnInOp(input_bns().at(ibn_idx))->shape();
int32_t concat_axis = op_conf().concat_conf().axis();
for (int64_t j = 0; j < ib_shape->NumAxes(); ++j) {
if (j == concat_axis || j == concat_axis + ib_shape->NumAxes()) {
vec[j] += ib_shape->At(j);
for (int64_t j = 0; j < ib_shape.NumAxes(); ++j) {
if (j == concat_axis || j == concat_axis + ib_shape.NumAxes()) {
vec[j] += ib_shape.At(j);
} else {
CHECK_EQ(vec[j], ib_shape->At(j));
CHECK_EQ(vec[j], ib_shape.At(j));
}
}
}
CHECK_EQ(vec.size(), GetShapePtr4BnInOp(input_bns().at(0))->NumAxes());
*GetShapePtr4BnInOp(SoleObn()) = Shape(vec);
CHECK_EQ(vec.size(),
GetBlobDesc4BnInOp(input_bns().at(0))->shape().NumAxes());
GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(vec);
}
REGISTER_OP(OperatorConf::kConcatConf, ConcatOp);
......
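`ConcatOp` applies the same accumulate-or-check rule, and the condition `j == concat_axis + ib_shape.NumAxes()` additionally accepts a negative axis counted from the back. Worked through the `concat_two_3x3` test below:

```cpp
// Two inputs of shape {3,3}, axis = 1 (equivalently axis = -1, since
// -1 + NumAxes() == 1 for a 2-D blob):
//   vec = {3, 3};  vec[1] += 3;  =>  out = Shape({3, 6})
```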
......@@ -15,8 +15,8 @@ class ConcatOp final : public UserOperator {
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -21,7 +21,7 @@ TEST(ConcatOp, concat_two_3x3) {
return bn2shape_ptr.at(bn);
};
// infershape
concat_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
concat_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = fp(concat_op->SoleObn());
ASSERT_EQ(*output_shape_ptr, Shape({3, 6}));
......
......@@ -22,15 +22,13 @@ const PbMessage& ConvolutionOp::GetSpecialConf() const {
return op_conf().convolution_conf();
}
void ConvolutionOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void ConvolutionOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn());
Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn());
Shape* colbuf_shape_ptr = GetShapePtr4BnInOp("col_buf");
const Shape& input_shape = GetBlobDesc4BnInOp(SoleIbn())->shape();
auto conv_conf = op_conf().convolution_conf();
int64_t batch_size = input_shape_ptr->At(0);
int64_t c_i = input_shape_ptr->At(1);
int64_t batch_size = input_shape.At(0);
int64_t c_i = input_shape.At(1);
int32_t out_num = GetInt32FromSpecialConf("out_num");
if (policy == kModelParallel) {
......@@ -43,32 +41,29 @@ void ConvolutionOp::InferShape4FwBlobs(
int64_t output_size = 1;
std::vector<int64_t> output_shape_vec = {batch_size, c_o};
int64_t h_len = (input_shape_ptr->At(2) + 2 * conv_conf.pad_h()
- conv_conf.kernel_size_h())
/ conv_conf.stride_h()
+ 1;
int64_t h_len =
(input_shape.At(2) + 2 * conv_conf.pad_h() - conv_conf.kernel_size_h())
/ conv_conf.stride_h()
+ 1;
output_shape_vec.push_back(h_len);
int64_t w_len = (input_shape_ptr->At(3) + 2 * conv_conf.pad_w()
- conv_conf.kernel_size_w())
/ conv_conf.stride_w()
+ 1;
int64_t w_len =
(input_shape.At(3) + 2 * conv_conf.pad_w() - conv_conf.kernel_size_w())
/ conv_conf.stride_w()
+ 1;
output_shape_vec.push_back(w_len);
kernel_size *= conv_conf.kernel_size_h();
kernel_size *= conv_conf.kernel_size_w();
output_size *= h_len;
output_size *= w_len;
*output_shape_ptr = Shape(output_shape_vec);
CHECK_EQ(output_shape_ptr->NumAxes(), input_shape_ptr->NumAxes());
*colbuf_shape_ptr = Shape({batch_size, output_size, c_i * kernel_size});
Shape* weight = GetShapePtr4BnInOp("weight");
*weight = Shape({c_o, c_i * kernel_size});
GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(output_shape_vec);
GetBlobDesc4BnInOp("col_buf")->mut_shape() =
Shape({batch_size, output_size, c_i * kernel_size});
GetBlobDesc4BnInOp("weight")->mut_shape() = Shape({c_o, c_i * kernel_size});
if (GetBoolFromSpecialConf("has_bias_term")) {
Shape* bias = GetShapePtr4BnInOp("bias");
Shape* biasmult_shape_ptr = GetShapePtr4BnInOp("bias_multiplier");
*bias = Shape({c_o});
*biasmult_shape_ptr = Shape({output_size});
GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({c_o});
GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({output_size});
}
}
......
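`ConvolutionOp` derives every produced blob from the input shape `{N, c_i, H, W}` using the usual extent formula `(in + 2*pad - kernel) / stride + 1`. With hypothetical numbers:

```cpp
// in = {100, 3, 11, 11}, kernel 3x3, pad 1, stride 2, out_num c_o:
//   h_len = (11 + 2*1 - 3) / 2 + 1 = 6;   w_len = 6
//   out     = {100, c_o, 6, 6}
//   col_buf = {100, 36, 27}   // {N, h_len*w_len, c_i * kernel_h * kernel_w}
//   weight  = {c_o, 27}
//   bias = {c_o}, bias_multiplier = {36}   // only with has_bias_term
```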
......@@ -13,8 +13,8 @@ class ConvolutionOp final : public UserOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
void FixParallelDesc(ParallelDesc* pr_desc) const override {
......
......@@ -37,7 +37,7 @@ void TestDataParallelConvolutionOp() {
};
// infershape
convolution_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
convolution_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = fp(convolution_op->SoleObn());
......@@ -71,7 +71,7 @@ void TestModelParallelConvolutionOp() {
};
// infershape
convolution_op->InferShape4FwBlobs(fp, kModelParallel, 3, 8);
convolution_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 3, 8);
// test
Shape* output_shape_ptr = fp(convolution_op->SoleObn());
......
......@@ -15,8 +15,8 @@ const PbMessage& DataLoaderOp::GetSpecialConf() const {
return op_conf().data_loader_conf();
}
void DataLoaderOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void DataLoaderOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
// useful vars
int32_t piece_size = JobDesc::Singleton()->piece_size();
......@@ -27,9 +27,9 @@ void DataLoaderOp::InferShape4FwBlobs(
feature_shape.insert(feature_shape.end(),
feature_shape_of_one_ins.dim_vec().begin(),
feature_shape_of_one_ins.dim_vec().end());
*GetShapePtr4BnInOp("feature") = Shape(feature_shape);
GetBlobDesc4BnInOp("feature")->mut_shape() = Shape(feature_shape);
// label shape
*GetShapePtr4BnInOp("label") = Shape({piece_size});
GetBlobDesc4BnInOp("label")->mut_shape() = Shape({piece_size});
}
REGISTER_OP(OperatorConf::kDataLoaderConf, DataLoaderOp);
......
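`DataLoaderOp` prepends the piece size to the per-instance feature shape and emits a flat label vector. With hypothetical configuration values:

```cpp
// piece_size = 200, feature_shape_of_one_ins = {3, 28, 28}:
//   feature = Shape({200, 3, 28, 28})
//   label   = Shape({200})
```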
......@@ -14,8 +14,8 @@ class DataLoaderOp final : public SysOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -21,10 +21,10 @@ const PbMessage& InnerProductOp::GetSpecialConf() const {
return op_conf().innerproduct_conf();
}
void InnerProductOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InnerProductOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* in_shape_ptr = GetShapePtr4BnInOp(SoleIbn());
const Shape& in_shape = GetBlobDesc4BnInOp(SoleIbn())->shape();
int32_t out_num = GetInt32FromSpecialConf("out_num");
if (policy == kModelParallel) {
BalancedSplitter splitter(out_num, parallel_num);
......@@ -32,22 +32,20 @@ void InnerProductOp::InferShape4FwBlobs(
}
// output bn
Shape* out_shape_ptr = GetShapePtr4BnInOp(SoleObn());
*out_shape_ptr = Shape({in_shape_ptr->At(0), out_num});
GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape({in_shape.At(0), out_num});
// model bn
Shape* weight_shape_ptr = GetShapePtr4BnInOp("weight");
*weight_shape_ptr = Shape({out_num, in_shape_ptr->Count(1)});
GetBlobDesc4BnInOp("weight")->mut_shape() =
Shape({out_num, in_shape.Count(1)});
if (GetBoolFromSpecialConf("has_bias_term")) {
// model bn
Shape* bias_shape_ptr = GetShapePtr4BnInOp("bias");
*bias_shape_ptr = Shape({1, out_num});
GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({1, out_num});
// model tmp bn
CHECK_EQ(model_tmp_bns().size(), 1);
Shape* bias_multiplier_shape_ptr = GetShapePtr4BnInOp("bias_multiplier");
*bias_multiplier_shape_ptr = Shape({in_shape_ptr->At(0), 1});
GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() =
Shape({in_shape.At(0), 1});
}
}
......
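`InnerProductOp` flattens everything after the batch axis into the weight's second dimension. Under `kModelParallel`, `out_num` is first narrowed to this shard's slice by `BalancedSplitter`; under `kDataParallel` it is used as-is. For a hypothetical input `{1000, 3, 4, 5}` with `out_num = 40`:

```cpp
//   out             = {1000, 40}   // {in.At(0), out_num}
//   weight          = {40, 60}     // {out_num, in.Count(1)}, 3*4*5 == 60
//   bias            = {1, 40}      // only with has_bias_term
//   bias_multiplier = {1000, 1}
```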
......@@ -13,8 +13,8 @@ class InnerProductOp final : public UserOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
void FixParallelDesc(ParallelDesc* pr_desc) const override {
......
......@@ -29,7 +29,7 @@ void TestModelParallelInnerProductOp(bool has_bias_term) {
return bn2shape_ptr.at(bn);
};
ip_op->InferShape4FwBlobs(fp, kModelParallel, 3, 10);
ip_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 3, 10);
BalancedSplitter splitter(40, 10);
int out_num = splitter.At(3).size();
......@@ -70,7 +70,7 @@ void TestDataParallelInnerProductOp(bool has_bias_term) {
return bn2shape_ptr.at(bn);
};
ip_op->InferShape4FwBlobs(fp, kDataParallel, 3, 10);
ip_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 3, 10);
Shape* out_shape_ptr = bn2shape_ptr.at(ip_op->SoleObn());
CHECK_EQ(*out_shape_ptr, Shape({1000, 40}));
......
......@@ -10,8 +10,8 @@ class ModelUpdtOp : public SysOperator {
OF_DISALLOW_COPY_AND_MOVE(ModelUpdtOp);
virtual ~ModelUpdtOp() = default;
virtual void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
virtual void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override {
// do nothing
......
......@@ -15,11 +15,10 @@ const PbMessage& MomentumModelUpdateOp::GetSpecialConf() const {
return op_conf().momentum_mdupdt_conf();
}
void MomentumModelUpdateOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void MomentumModelUpdateOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* input_shape_ptr = GetShapePtr4BnInOp("model_diffs");
*GetShapePtr4BnInOp("momentum") = *input_shape_ptr;
TODO();
}
REGISTER_OP(OperatorConf::kMomentumMdupdtConf, MomentumModelUpdateOp);
......
......@@ -13,8 +13,8 @@ class MomentumModelUpdateOp final : public ModelUpdtOp {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -16,11 +16,11 @@ const PbMessage& MultinomialLogisticLossOp::GetSpecialConf() const {
return op_conf().multinomial_logistic_loss_conf();
}
void MultinomialLogisticLossOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void MultinomialLogisticLossOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
*GetShapePtr4BnInOp("loss") = Shape({1});
*GetShapePtr4BnInOp("loss_buffer") = Shape({1});
GetBlobDesc4BnInOp("loss")->mut_shape() = Shape({1});
GetBlobDesc4BnInOp("loss_buffer")->mut_shape() = Shape({1});
}
REGISTER_OP(OperatorConf::kMultinomialLogisticLossConf,
......
......@@ -17,8 +17,8 @@ class MultinomialLogisticLossOp final : public UserOperator {
const PbMessage& GetSpecialConf() const override;
bool IsLossOp() const override { return true; }
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -21,7 +21,7 @@ TEST(MultinomialLogisticLossOp, test_loss_op) {
return bn2shape_ptr.at(bn);
};
loss_op->InferShape4FwBlobs(fp, kDataParallel, 2, 10);
loss_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 2, 10);
Shape* loss_shape_ptr = bn2shape_ptr.at(loss_op->SoleObn());
Shape* loss_buffer_shape_ptr = bn2shape_ptr.at(loss_op->SoleDtbn());
......
......@@ -2,13 +2,13 @@
#define ONEFLOW_CORE_OPERATOR_OPERATOR_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/keyword.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/operator.pb.h"
#include "oneflow/core/register/blob_desc.h"
namespace oneflow {
......@@ -79,8 +79,8 @@ class Operator {
// Read: BlobDesc of input_blobs
// Write: BlobDesc of output_blobs, model_blobs, data_tmp_blobs, model_tmp_blobs
virtual void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
virtual void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const = 0;
......@@ -144,8 +144,8 @@ class SysOperator : public Operator {
SysOperator() = default;
virtual ~SysOperator() = default;
virtual void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
virtual void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override {
UNEXPECTED_RUN();
......
......@@ -15,32 +15,30 @@ const PbMessage& PoolingOp::GetSpecialConf() const {
return op_conf().pooling_conf();
}
void PoolingOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void PoolingOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn());
CHECK_EQ(input_shape_ptr->NumAxes(), 4);
Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn());
const Shape& input_shape = GetBlobDesc4BnInOp(SoleIbn())->shape();
CHECK_EQ(input_shape.NumAxes(), 4);
BlobDesc* output_blob_desc = GetBlobDesc4BnInOp(SoleObn());
const PoolingOpConf& pooling_conf = op_conf().pooling_conf();
std::vector<int64_t> output_shape_dim_vec = {input_shape_ptr->At(0),
input_shape_ptr->At(1)};
std::vector<int64_t> output_shape_dim_vec = {input_shape.At(0),
input_shape.At(1)};
output_shape_dim_vec.push_back((input_shape_ptr->At(2)
+ 2 * pooling_conf.pad_h()
output_shape_dim_vec.push_back((input_shape.At(2) + 2 * pooling_conf.pad_h()
- pooling_conf.kernel_size_h())
/ pooling_conf.stride_h()
+ 1);
output_shape_dim_vec.push_back((input_shape_ptr->At(3)
+ 2 * pooling_conf.pad_w()
output_shape_dim_vec.push_back((input_shape.At(3) + 2 * pooling_conf.pad_w()
- pooling_conf.kernel_size_w())
/ pooling_conf.stride_w()
+ 1);
*output_shape_ptr = Shape(output_shape_dim_vec);
Shape* data_tmp_shape_ptr = GetShapePtr4BnInOp(SoleDtbn());
*data_tmp_shape_ptr = Shape(output_shape_dim_vec);
output_blob_desc->mut_shape() = Shape(output_shape_dim_vec);
BlobDesc* data_tmp_blob_desc = GetBlobDesc4BnInOp(SoleDtbn());
data_tmp_blob_desc->mut_shape() = Shape(output_shape_dim_vec);
}
REGISTER_OP(OperatorConf::kPoolingConf, PoolingOp);
......
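`PoolingOp` keeps the batch and channel extents and applies the same `(in + 2*pad - kernel) / stride + 1` rule to the two spatial axes; the data-tmp blob mirrors the output shape. With hypothetical parameters:

```cpp
// in = {100, 64, 11, 11}, kernel 3x3, pad 0, stride 2:
//   h = w = (11 + 0 - 3) / 2 + 1 = 5
//   out = data_tmp = {100, 64, 5, 5}
```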
......@@ -16,8 +16,8 @@ class PoolingOp final : public UserOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -30,7 +30,7 @@ TEST(PoolingOp, pool_100x64x11x11) {
return bn2shape_ptr.at(bn);
};
// do infer shape
pooling_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
pooling_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = bn2shape_ptr.at(pooling_op->SoleObn());
Shape* data_tmp_shape_ptr = bn2shape_ptr.at(pooling_op->SoleDtbn());
......
......@@ -15,8 +15,8 @@ class RecordOp final : public SysOperator {
const PbMessage& GetSpecialConf() const override;
bool IsRecordOp() const override { return true; }
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override {}
......
......@@ -14,12 +14,10 @@ const PbMessage& ReluOp::GetSpecialConf() const {
return op_conf().relu_conf();
}
void ReluOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void ReluOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn());
Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn());
*output_shape_ptr = *input_shape_ptr;
*GetBlobDesc4BnInOp(SoleObn()) = *GetBlobDesc4BnInOp(SoleIbn());
}
REGISTER_OP(OperatorConf::kReluConf, ReluOp);
......
......@@ -15,8 +15,8 @@ class ReluOp final : public UserOperator {
const PbMessage& GetSpecialConf() const override;
bool IsElemWise() const override { return true; }
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -17,7 +17,7 @@ TEST(ReluOp, relu_3x5x4) {
return bn2shape_ptr.at(bn);
};
// do infer shape
relu_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
relu_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* input_shape_ptr = bn2shape_ptr.at(relu_op->SoleIbn());
Shape* output_shape_ptr = bn2shape_ptr.at(relu_op->SoleObn());
......
......@@ -15,11 +15,10 @@ const PbMessage& RMSPropModelUpdateOp::GetSpecialConf() const {
return op_conf().rmsprop_mdupdt_conf();
}
void RMSPropModelUpdateOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void RMSPropModelUpdateOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
Shape* input_shape_ptr = GetShapePtr4BnInOp("model_diffs");
*GetShapePtr4BnInOp("mean_square") = *input_shape_ptr;
TODO();
}
REGISTER_OP(OperatorConf::kRmspropMdupdtConf, RMSPropModelUpdateOp);
......
......@@ -13,8 +13,8 @@ class RMSPropModelUpdateOp final : public ModelUpdtOp {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -17,16 +17,16 @@ const PbMessage& SoftmaxLossOp::GetSpecialConf() const {
return op_conf().softmax_loss_conf();
}
void SoftmaxLossOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void SoftmaxLossOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
const std::vector<int64_t> in_dim_vec =
GetShapePtr4BnInOp("prediction")->dim_vec();
GetBlobDesc4BnInOp("prediction")->shape().dim_vec();
CHECK_EQ(in_dim_vec.size(), 2);
CHECK_EQ(*GetShapePtr4BnInOp("label"), Shape({in_dim_vec[0]}));
*GetShapePtr4BnInOp(SoleObn()) = Shape({1});
*GetShapePtr4BnInOp("prob") = Shape(in_dim_vec);
*GetShapePtr4BnInOp("tmp_1D") = Shape({in_dim_vec[0]});
CHECK_EQ(GetBlobDesc4BnInOp("label")->shape(), Shape({in_dim_vec[0]}));
GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape({1});
GetBlobDesc4BnInOp("prob")->mut_shape() = Shape(in_dim_vec);
GetBlobDesc4BnInOp("tmp_1D")->mut_shape() = Shape({in_dim_vec[0]});
}
REGISTER_OP(OperatorConf::kSoftmaxLossConf, SoftmaxLossOp);
......
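`SoftmaxLossOp` expects a 2-D prediction `{N, C}` and a label of shape `{N}`, and produces a scalar loss plus two work blobs:

```cpp
// prediction = {3, 5}, label = {3} (the softmax_loss_3x5 test below):
//   loss   = Shape({1})
//   prob   = Shape({3, 5})   // same as prediction
//   tmp_1D = Shape({3})
```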
......@@ -15,8 +15,8 @@ class SoftmaxLossOp final : public UserOperator {
const PbMessage& GetSpecialConf() const override;
bool IsLossOp() const override { return true; }
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -20,7 +20,7 @@ TEST(SoftmaxLossOp, softmax_loss_3x5) {
return bn2shape_ptr.at(bn);
};
// infershape
softmax_loss_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
softmax_loss_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
ASSERT_EQ(*fp("loss"), Shape({1}));
ASSERT_EQ(*fp("prob"), Shape({3, 5}));
......
......@@ -15,13 +15,14 @@ const PbMessage& SoftmaxOp::GetSpecialConf() const {
return op_conf().softmax_conf();
}
void SoftmaxOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void SoftmaxOp::InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
std::vector<int64_t> vec = GetShapePtr4BnInOp(SoleIbn())->dim_vec();
const std::vector<int64_t>& vec =
GetBlobDesc4BnInOp(SoleIbn())->shape().dim_vec();
CHECK_EQ(vec.size(), 2);
*GetShapePtr4BnInOp(SoleObn()) = Shape(vec);
*GetShapePtr4BnInOp(SoleDtbn()) = Shape({vec[0]});
GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(vec);
GetBlobDesc4BnInOp(SoleDtbn())->mut_shape() = Shape({vec[0]});
}
REGISTER_OP(OperatorConf::kSoftmaxConf, SoftmaxOp);
......
......@@ -14,8 +14,8 @@ class SoftmaxOp final : public UserOperator {
void InitFromOpConf(const OperatorConf& op_conf) override;
const PbMessage& GetSpecialConf() const override;
void InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
void InferBlobDesc4FwBlobs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
ParallelPolicy policy, int64_t parallel_id,
int64_t parallel_num) const override;
......
......@@ -17,7 +17,7 @@ TEST(SoftmaxOp, softmax_3x5) {
return bn2shape_ptr.at(bn);
};
// infershape
softmax_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
softmax_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = fp(softmax_op->SoleObn());
Shape* tmp_shape_ptr = fp(softmax_op->SoleDtbn());
......
......@@ -6,7 +6,7 @@
namespace oneflow {
class Blob {
class Blob final {
public:
OF_DISALLOW_COPY_AND_MOVE(Blob);
Blob(void* dptr, const Shape* shape) : dptr_(dptr), shape_(shape) {}
......
#ifndef ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
#define ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/register/blob_desc.pb.h"
namespace oneflow {
class BlobDesc final {
public:
// OF_DISALLOW_COPY_AND_MOVE(BlobDesc);
BlobDesc() = default;
~BlobDesc() = default;
BlobDesc(const BlobDescProto& proto) { shape_ = Shape(proto.shape()); }
const Shape& shape() const { return shape_; }
Shape& mut_shape() { return shape_; }
void ToProto(BlobDescProto* proto) const {
shape_.ToProto(proto->mutable_shape());
}
private:
Shape shape_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_
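`BlobDesc` is the new unit of register metadata; so far it only wraps a `Shape`, with proto round-tripping. A minimal usage sketch:

```cpp
BlobDesc desc;
desc.mut_shape() = Shape({10, 5, 6, 6});

BlobDescProto proto;
desc.ToProto(&proto);        // writes proto.shape

BlobDesc restored(proto);    // reconstructs from the proto
CHECK_EQ(restored.shape(), desc.shape());
```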
syntax = "proto3";
package oneflow;
import "oneflow/core/common/shape.proto";
message BlobDescProto {
ShapeProto shape = 1;
}
......@@ -40,57 +40,57 @@ void RegstDesc::AddConsumer(const TaskNode* new_consumer) {
}
void RegstDesc::CopyLbnFrom(const RegstDesc* rhs) {
lbn2shape_.clear();
for (const auto& pair : rhs->lbn2shape_) {
CHECK(lbn2blob_desc_.empty());
for (const auto& pair : rhs->lbn2blob_desc_) {
const std::string& lbn = pair.first;
auto shape = of_make_unique<Shape>();
CHECK(lbn2shape_.emplace(lbn, std::move(shape)).second);
CHECK(lbn2blob_desc_.emplace(lbn, of_make_unique<BlobDesc>()).second);
}
}
void RegstDesc::CopyShapeFrom(const RegstDesc* rhs) {
for (const auto& pair : lbn2shape_) {
void RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) {
for (const auto& pair : lbn2blob_desc_) {
const std::string& lbn = pair.first;
*(lbn2shape_.at(lbn)) = rhs->GetShape(lbn);
*(lbn2blob_desc_.at(lbn)) = rhs->GetBlobDesc(lbn);
}
}
void RegstDesc::EnrollLbn(const std::string& lbn) {
std::unique_ptr<Shape> ptr(new Shape);
CHECK(lbn2shape_.emplace(lbn, std::move(ptr)).second) << lbn;
CHECK(lbn2blob_desc_.emplace(lbn, of_make_unique<BlobDesc>()).second) << lbn;
}
const Shape& RegstDesc::GetShape(const std::string& lbn) const {
return *(lbn2shape_.at(lbn));
const BlobDesc& RegstDesc::GetBlobDesc(const std::string& lbn) const {
return *(lbn2blob_desc_.at(lbn));
}
Shape* RegstDesc::GetMutShapePtr(const std::string& lbn) {
return lbn2shape_.at(lbn).get();
BlobDesc* RegstDesc::GetMutBlobDesc(const std::string& lbn) {
return lbn2blob_desc_.at(lbn).get();
}
void RegstDesc::ForEachLbn(std::function<void(const std::string&)> func) const {
for (const auto& p : lbn2shape_) { func(p.first); }
for (const auto& p : lbn2blob_desc_) { func(p.first); }
}
void RegstDesc::EraseZeroSizeBlob() {
EraseIf<std::string, std::unique_ptr<Shape>>(
&lbn2shape_,
[](HashMap<std::string, std::unique_ptr<Shape>>::iterator it) {
return it->second->elem_cnt() == 0;
EraseIf<std::string, std::unique_ptr<BlobDesc>>(
&lbn2blob_desc_,
[](HashMap<std::string, std::unique_ptr<BlobDesc>>::iterator it) {
return it->second->shape().elem_cnt() == 0;
});
}
int64_t RegstDesc::CompElemCntOfAllBlob() const {
int64_t sum = 0;
for (const auto& pair : lbn2shape_) { sum += pair.second->elem_cnt(); }
for (const auto& pair : lbn2blob_desc_) {
sum += pair.second->shape().elem_cnt();
}
return sum;
}
std::string RegstDesc::DebugStr() const {
std::stringstream ss;
ss << "{";
for (const auto& pair : lbn2shape_) {
ss << "{" << pair.first << ":" << pair.second->DebugStr() << "}";
for (const auto& pair : lbn2blob_desc_) {
ss << "{" << pair.first << ":" << pair.second->shape().DebugStr() << "}";
}
ss << "}";
return ss.str();
......@@ -104,10 +104,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
ret->add_consumer_task_id(consumer->task_id());
}
}
for (const auto& pair : lbn2shape_) {
PbMapPair<std::string, ShapeProto> pb_pair(pair.first);
for (const auto& pair : lbn2blob_desc_) {
PbMapPair<std::string, BlobDescProto> pb_pair(pair.first);
pair.second->ToProto(&(pb_pair.second));
ret->mutable_lbn2shape()->insert(pb_pair);
ret->mutable_lbn2blob_desc()->insert(pb_pair);
}
ret->set_register_num(register_num_);
*(ret->mutable_mem_case()) = InferMemCase();
......@@ -140,4 +140,9 @@ MemoryCase RegstDesc::InferMemCase() const {
return mem_case;
}
BlobDesc RegstDesc::CompPackedBlobDesc() const {
BlobDesc packed_blob_desc;
packed_blob_desc.mut_shape() = Shape({CompElemCntOfAllBlob()});
return packed_blob_desc;
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
#define ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace oneflow {
// Regst : Register
// Contig : Contiguous
class TaskNode;
class RegstDesc final {
......@@ -27,28 +23,29 @@ class RegstDesc final {
const HashSet<const TaskNode*>& consumers() const { return consumers_; }
void AddConsumer(const TaskNode*);
// Lbn and Shape
// Lbn and BlobDesc
void CopyLbnFrom(const RegstDesc*);
void CopyShapeFrom(const RegstDesc*);
void CopyBlobDescFrom(const RegstDesc*);
void EnrollLbn(const std::string& lbn);
const Shape& GetShape(const std::string& lbn) const;
Shape* GetMutShapePtr(const std::string& lbn);
const BlobDesc& GetBlobDesc(const std::string& lbn) const;
BlobDesc* GetMutBlobDesc(const std::string& lbn);
void ForEachLbn(std::function<void(const std::string&)> func) const;
size_t NumOfLbn() const { return lbn2shape_.size(); }
size_t NumOfLbn() const { return lbn2blob_desc_.size(); }
//
void EraseZeroSizeBlob();
int64_t CompElemCntOfAllBlob() const;
std::string DebugStr() const;
void ToProto(RegstDescProto*) const;
MemoryCase InferMemCase() const;
BlobDesc CompPackedBlobDesc() const;
private:
int64_t CompElemCntOfAllBlob() const;
int64_t regst_desc_id_;
const TaskNode* producer_;
HashSet<const TaskNode*> consumers_;
HashMap<std::string, std::unique_ptr<Shape>> lbn2shape_;
HashMap<std::string, std::unique_ptr<BlobDesc>> lbn2blob_desc_;
int64_t register_num_;
};
......
syntax = "proto3";
package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/memory/memory_case.proto";
message RegstDescProto {
int64 regst_desc_id = 1;
int64 producer_task_id = 2;
repeated int64 consumer_task_id = 3;
map<string, ShapeProto> lbn2shape = 4;
map<string, BlobDescProto> lbn2blob_desc = 4;
int64 register_num = 5;
MemoryCase mem_case = 6;
}
......@@ -19,12 +19,12 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
}
int64_t elem_cnt = 0;
std::vector<std::string> lbns;
lbns.reserve(regst_desc_proto.lbn2shape().size());
for (const auto& pair : regst_desc_proto.lbn2shape()) {
const Shape* shape_ptr =
runtime_regst_desc->GetShapePtrFromLbn(pair.first);
lbns.reserve(regst_desc_proto.lbn2blob_desc().size());
for (const auto& pair : regst_desc_proto.lbn2blob_desc()) {
const Shape& shape =
runtime_regst_desc->GetBlobDescFromLbn(pair.first)->shape();
lbns.push_back(pair.first);
elem_cnt += shape_ptr->elem_cnt();
elem_cnt += shape.elem_cnt();
}
std::sort(lbns.begin(), lbns.end());
std::pair<char*, std::function<void()>> allocation =
......@@ -33,7 +33,8 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
int64_t blob_idx = 0;
for (const std::string& lbn : lbns) {
const Shape* shape_ptr = runtime_regst_desc->GetShapePtrFromLbn(lbn);
const Shape* shape_ptr =
&(runtime_regst_desc->GetBlobDescFromLbn(lbn)->shape());
auto blob_ptr =
of_make_unique<Blob>(allocation.first + blob_idx, shape_ptr);
CHECK(regst->lbn2blob_.emplace(lbn, std::move(blob_ptr)).second);
......
......@@ -15,8 +15,9 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& regst_desc_proto) {
consumers_actor_id_.push_back(IDMgr::Singleton()->ActorId4TaskId(task_id));
}
for (const auto& pair : regst_desc_proto.lbn2shape()) {
CHECK(lbn2shape_.emplace(pair.first, of_make_unique<Shape>(pair.second))
for (const auto& pair : regst_desc_proto.lbn2blob_desc()) {
CHECK(lbn2blob_desc_
.emplace(pair.first, of_make_unique<BlobDesc>(pair.second))
.second);
}
mem_case_ = regst_desc_proto.mem_case();
......
#ifndef ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_
#include "oneflow/core/common/shape.h"
#include "oneflow/core/memory/memory_case.pb.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/register/register_desc.pb.h"
namespace oneflow {
......@@ -23,15 +23,15 @@ class RtRegstDesc {
int64_t register_num() const { return register_num_; }
const MemoryCase& mem_case() const { return mem_case_; }
const Shape* GetShapePtrFromLbn(const std::string& lbn) const {
return lbn2shape_.at(lbn).get();
const BlobDesc* GetBlobDescFromLbn(const std::string& lbn) const {
return lbn2blob_desc_.at(lbn).get();
}
private:
int64_t regst_desc_id_;
int64_t producer_actor_id_;
std::vector<int64_t> consumers_actor_id_;
std::unordered_map<std::string, std::unique_ptr<Shape>> lbn2shape_;
std::unordered_map<std::string, std::unique_ptr<BlobDesc>> lbn2blob_desc_;
int64_t register_num_;
MemoryCase mem_case_;
};
......