提交 88a40be1 编写于 作者: W willzhang4a58

concat op

上级 aab7c593
......@@ -9,6 +9,9 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(oneflow_src_dir ${PROJECT_SOURCE_DIR}/oneflow)
set(oneflow_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
# option
option(download_third_party "Download and build all third party codes" OFF)
# Force to link static cxx runtime library
if (MSVC)
foreach(flag_var
......@@ -62,7 +65,9 @@ RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
add_library(of_protoobj ${PROTO_SRCS} ${PROTO_HDRS})
target_link_libraries(of_protoobj ${oneflow_third_party_libs})
add_dependencies(of_protoobj ${oneflow_third_party_dependencies})
if (download_third_party)
add_dependencies(of_protoobj ${oneflow_third_party_dependencies})
endif()
# cc obj lib
include_directories(${oneflow_src_dir})
......
......@@ -120,7 +120,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
// Construct Op
OperatorConf op_conf;
op_conf.set_name("TODO");
BoxingOpConf* box_conf = op_conf.mutable_boxing_op_conf();
BoxingOpConf* box_conf = op_conf.mutable_boxing_conf();
box_conf->set_lbn(lbn);
box_conf->set_in_num(sorted_in_edges.size());
box_conf->set_out_num(sorted_out_edges.size());
......
......@@ -87,8 +87,8 @@ void LogicalGraph::CollectCloneInfos(
// Construct clone op
OperatorConf pb_op_conf;
pb_op_conf.set_name("clone_" + lbn + "_" + cur_node->node_id_str());
pb_op_conf.mutable_clone_op_conf()->set_out_num(edges.size());
pb_op_conf.mutable_clone_op_conf()->set_lbn(lbn);
pb_op_conf.mutable_clone_conf()->set_out_num(edges.size());
pb_op_conf.mutable_clone_conf()->set_lbn(lbn);
auto clone_op = ConstructOpFromPbConf(pb_op_conf);
// Set clone_info
CloneInfo clone_info;
......
......@@ -9,12 +9,12 @@ void SetModelLoadChain(ChainNode* model_load_chain) {
// model load op
OperatorConf op_conf;
op_conf.set_name("");
op_conf.mutable_model_load_op_conf();
op_conf.mutable_model_load_conf();
model_load_chain->mut_op_vec() = {ConstructOpFromPbConf(op_conf)};
// model load parallel_conf
ParallelConf pr_conf;
pr_conf.set_policy(kDataParallel);
pr_conf.add_devices(JobDesc::Singleton().MdLoadMachine() + "/disk");
pr_conf.add_devices(JobDesc::Singleton().md_load_machine() + "/disk");
model_load_chain->mut_parallel_desc().reset(new ParallelDesc(pr_conf));
// output
model_load_chain->mut_output_lbns() = {RegstDesc::kAllLbn};
......
......@@ -8,12 +8,12 @@ void SetModelSaveChain(ChainNode* model_save_chain) {
// model save op
OperatorConf op_conf;
op_conf.set_name("");
op_conf.mutable_model_save_op_conf();
op_conf.mutable_model_save_conf();
model_save_chain->mut_op_vec() = {ConstructOpFromPbConf(op_conf)};
// model save parallel_conf
ParallelConf pr_conf;
pr_conf.set_policy(kDataParallel);
pr_conf.add_devices(JobDesc::Singleton().MdSaveMachine() + "/disk");
pr_conf.add_devices(JobDesc::Singleton().md_save_machine() + "/disk");
model_save_chain->mut_parallel_desc().reset(new ParallelDesc(pr_conf));
// output
model_save_chain->mut_input_lbns() = {RegstDesc::kAllLbn};
......
......@@ -18,7 +18,7 @@ void MdUpdtTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
// Construct ModelUpdateOp
OperatorConf op_conf;
op_conf.set_name("model_update_" + data_chain->ConcatedOpsName());
op_conf.mutable_model_update_op_conf();
op_conf.mutable_model_update_conf();
auto model_update_op = ConstructOpFromPbConf(op_conf);
// ModelUpdateChain
auto chain_gph = make_unique<ChainGraph> ();
......
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "job/id_manager.h"
#include "graph/task_graph_manager.h"
#include "job/job_conf.pb.h"
DEFINE_string(job_user_conf_filepath, "", "");
using oneflow::JobConf;
using oneflow::JobDesc;
using oneflow::IDMgr;
using oneflow::TaskGraphMgr;
DEFINE_string(job_conf_filepath, "", "");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
oneflow::JobUserConf job_user_conf;
ParseProtoFromTextFile(FLAGS_job_user_conf_filepath, &job_user_conf);
oneflow::JobDesc::Singleton().InitFromJobUserConf(job_user_conf);
oneflow::IDMgr::Singleton().InitFromResource(oneflow::JobDesc::Singleton().resource());
oneflow::TaskGraphMgr::Singleton().Init();
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
JobConf job_conf;
ParseProtoFromTextFile(FLAGS_job_conf_filepath, &job_conf);
JobDesc::Singleton().InitFromJobConf(job_conf);
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
TaskGraphMgr::Singleton().Init();
return 0;
}
......@@ -5,8 +5,10 @@ import "job/dlnet_conf.proto";
import "job/resource.proto";
import "job/strategy.proto";
message JobUserConf {
message JobConf {
string train_dlnet_conf_filepath = 1;
string resource_filepath = 2;
string strategy_filepath = 3;
string model_load_machine = 4;
string model_save_machine = 5;
}
#include "job/job_desc.h"
#include "common/proto_io.h"
namespace oneflow {
void JobDesc::InitFromJobConf(const JobConf& conf) {
ParseProtoFromTextFile(conf.train_dlnet_conf_filepath(), &train_dl_net_conf_);
ParseProtoFromTextFile(conf.resource_filepath(), &resource_);
ParseProtoFromTextFile(conf.strategy_filepath(), &strategy_);
md_load_machine_ = conf.model_load_machine();
md_save_machine_ = conf.model_save_machine();
}
void JobDesc::InitFromProto(const JobDescProto&) {
TODO();
}
JobDescProto JobDesc::ToProto() const {
TODO();
}
} // namespace oneflow
......@@ -17,20 +17,27 @@ class JobDesc final {
return obj;
}
void InitFromJobUserConf(const JobUserConf&) { TODO(); }
void InitFromProto(const JobDescProto&) { TODO(); }
JobDescProto ToProto() const { TODO(); }
void InitFromJobConf(const JobConf&);
void InitFromProto(const JobDescProto&);
JobDescProto ToProto() const;
const DLNetConf& train_dlnet_conf() const { TODO(); }
const Resource& resource() const { TODO(); }
const Strategy& strategy() const { TODO(); }
// Getters
const DLNetConf& train_dlnet_conf() const { return train_dl_net_conf_; }
const Resource& resource() const { return resource_; }
const Strategy& strategy() const { return strategy_; }
const std::string& MdLoadMachine() { TODO(); }
const std::string& MdSaveMachine() { TODO(); }
const std::string& md_load_machine() { return md_load_machine_; }
const std::string& md_save_machine() { return md_save_machine_; }
private:
JobDesc() = default;
DLNetConf train_dl_net_conf_;
Resource resource_;
Strategy strategy_;
std::string md_load_machine_;
std::string md_save_machine_;
};
} // namespace oneflow
......
......@@ -13,7 +13,7 @@ enum DeviceType {
}
message Resource {
repeated Machine machines = 1;
repeated Machine machine = 1;
int64 device_num_per_machine = 2;
DeviceType device_type = 3;
}
......
......@@ -5,8 +5,8 @@ namespace oneflow {
void BoxingOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_boxing_op_conf());
auto cnf = new BoxingOpConf(op_conf.boxing_op_conf());
CHECK(op_conf.has_boxing_conf());
auto cnf = new BoxingOpConf(op_conf.boxing_conf());
mut_pb_op_conf().reset(cnf);
for (int64_t i = 0; i < cnf->in_num(); ++i) {
......
......@@ -5,8 +5,8 @@ namespace oneflow {
void CloneOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_clone_op_conf());
auto cnf = new CloneOpConf(op_conf.clone_op_conf());
CHECK(op_conf.has_clone_conf());
auto cnf = new CloneOpConf(op_conf.clone_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
#include "operator/concat_op.h"
namespace oneflow {
void ConcatOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_concat_conf());
auto cnf = new ConcatOpConf(op_conf.concat_conf());
mut_pb_op_conf().reset(cnf);
for (int i = 0; i < cnf->in_size(); ++i) {
EnrollInputBn("in_" + std::to_string(i));
}
EnrollOutputBn("out");
}
} // namespace oneflow
#ifndef ONEFLOW_OPERATOR_CONCAT_OP_H_
#define ONEFLOW_OPERATOR_CONCAT_OP_H_
#include "operator/operator.h"
namespace oneflow {
class ConcatOp final : public UserOperator {
public:
OF_DISALLOW_COPY_AND_MOVE(ConcatOp);
ConcatOp() = default;
~ConcatOp() = default;
void Init(const OperatorConf& op_conf) override;
std::string normal_ibn2lbn(const std::string& input_bn) const override { TODO(); }
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
void InferShape4Mdb() const override { TODO(); }
private:
};
} // namespace oneflow
#endif // ONEFLOW_OPERATOR_CONCAT_OP_H_
......@@ -6,8 +6,8 @@ namespace oneflow {
void ConvolutionOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_convolution_op_conf());
auto cnf = new ConvolutionOpConf(op_conf.convolution_op_conf());
CHECK(op_conf.has_convolution_conf());
auto cnf = new ConvolutionOpConf(op_conf.convolution_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
......@@ -5,8 +5,8 @@ namespace oneflow {
void CopyOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_copy_op_conf());
auto cnf = new CopyOpConf(op_conf.copy_op_conf());
CHECK(op_conf.has_copy_conf());
auto cnf = new CopyOpConf(op_conf.copy_conf());
mut_pb_op_conf().reset(cnf);
for (int64_t i = 0; i < cnf->copied_lbns_size(); ++i) {
......
......@@ -6,8 +6,8 @@ namespace oneflow {
void DataLoaderOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_data_loader_op_conf());
auto cnf = new DataLoaderOpConf(op_conf.data_loader_op_conf());
CHECK(op_conf.has_data_loader_conf());
auto cnf = new DataLoaderOpConf(op_conf.data_loader_conf());
mut_pb_op_conf().reset(cnf);
EnrollOutputBn("data", false);
......
......@@ -6,8 +6,8 @@ namespace oneflow {
void InnerProductOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_inner_product_op_conf());
auto cnf = new InnerProductOpConf(op_conf.inner_product_op_conf());
CHECK(op_conf.has_inner_product_conf());
auto cnf = new InnerProductOpConf(op_conf.inner_product_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
......@@ -6,9 +6,9 @@ namespace oneflow {
void MultinomialLogisticLossOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_multinomial_logistic_loss_op_conf());
CHECK(op_conf.has_multinomial_logistic_loss_conf());
auto cnf = new MultinomialLogisticLossOpConf(
op_conf.multinomial_logistic_loss_op_conf());
op_conf.multinomial_logistic_loss_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("data");
......
......@@ -37,6 +37,12 @@ message MultinomialLogisticLossOpConf {
string loss = 3;
}
message ConcatOpConf {
repeated string in = 1;
string out = 2;
int64 axis = 3;
}
message CopyOpConf {
enum CopyType {
H2D = 0;
......@@ -46,6 +52,7 @@ message CopyOpConf {
repeated string copied_lbns = 2;
}
message CloneOpConf {
int64 out_num = 1;
string lbn = 2;
......@@ -85,18 +92,19 @@ message ModelSaveOpConf {
message OperatorConf {
string name = 1;
oneof specified_type {
ConvolutionOpConf convolution_op_conf = 100;
InnerProductOpConf inner_product_op_conf = 101;
DataLoaderOpConf data_loader_op_conf = 102;
PoolingOpConf pooling_op_conf = 103;
ReluOpConf relu_op_conf = 104;
SoftmaxOpConf softmax_op_conf = 105;
MultinomialLogisticLossOpConf multinomial_logistic_loss_op_conf = 106;
CopyOpConf copy_op_conf = 107;
CloneOpConf clone_op_conf = 108;
BoxingOpConf boxing_op_conf = 109;
ModelUpdateOpConf model_update_op_conf = 110;
ModelLoadOpConf model_load_op_conf = 111;
ModelSaveOpConf model_save_op_conf = 112;
ConvolutionOpConf convolution_conf = 100;
InnerProductOpConf inner_product_conf = 101;
DataLoaderOpConf data_loader_conf = 102;
PoolingOpConf pooling_conf = 103;
ReluOpConf relu_conf = 104;
SoftmaxOpConf softmax_conf = 105;
MultinomialLogisticLossOpConf multinomial_logistic_loss_conf = 106;
CopyOpConf copy_conf = 107;
CloneOpConf clone_conf = 108;
BoxingOpConf boxing_conf = 109;
ModelUpdateOpConf model_update_conf = 110;
ModelLoadOpConf model_load_conf = 111;
ModelSaveOpConf model_save_conf = 112;
ConcatOpConf concat_conf = 113;
}
}
......@@ -15,31 +15,31 @@ std::shared_ptr<Operator> OperatorFactory::ConstructOp(
const OperatorConf& op_conf) const {
std::shared_ptr<Operator> ret;
switch (op_conf.specified_type_case()) {
case OperatorConf::kConvolutionOpConf: {
case OperatorConf::kConvolutionConf: {
ret.reset(new ConvolutionOp);
break;
}
case OperatorConf::kInnerProductOpConf: {
case OperatorConf::kInnerProductConf: {
ret.reset(new InnerProductOp);
break;
}
case OperatorConf::kDataLoaderOpConf: {
case OperatorConf::kDataLoaderConf: {
ret.reset(new DataLoaderOp);
break;
}
case OperatorConf::kPoolingOpConf: {
case OperatorConf::kPoolingConf: {
ret.reset(new PoolingOp);
break;
}
case OperatorConf::kReluOpConf: {
case OperatorConf::kReluConf: {
ret.reset(new ReluOp);
break;
}
case OperatorConf::kSoftmaxOpConf: {
case OperatorConf::kSoftmaxConf: {
ret.reset(new SoftmaxOp);
break;
}
case OperatorConf::kMultinomialLogisticLossOpConf: {
case OperatorConf::kMultinomialLogisticLossConf: {
ret.reset(new MultinomialLogisticLossOp);
break;
}
......
......@@ -6,8 +6,8 @@ namespace oneflow {
void PoolingOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_pooling_op_conf());
auto cnf = new PoolingOpConf(op_conf.pooling_op_conf());
CHECK(op_conf.has_pooling_conf());
auto cnf = new PoolingOpConf(op_conf.pooling_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
......@@ -6,8 +6,8 @@ namespace oneflow {
void ReluOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_relu_op_conf());
auto cnf = new ReluOpConf(op_conf.relu_op_conf());
CHECK(op_conf.has_relu_conf());
auto cnf = new ReluOpConf(op_conf.relu_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
......@@ -6,8 +6,8 @@ namespace oneflow {
void SoftmaxOp::Init(const OperatorConf& op_conf) {
mut_op_name() = op_conf.name();
CHECK(op_conf.has_softmax_op_conf());
auto cnf = new SoftmaxOpConf(op_conf.softmax_op_conf());
CHECK(op_conf.has_softmax_conf());
auto cnf = new SoftmaxOpConf(op_conf.softmax_conf());
mut_pb_op_conf().reset(cnf);
EnrollInputBn("in");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册