提交 a68583b3 编写于 作者: S ShawnXuan 提交者: Jinhui Yuan

alexnet, resnet and inception v3 generator (#993)

* alexnet, resnet and inception v3 generator

* move net gen apps to tools
上级 4a0f6219
# main cpp
# Sources collected in of_main_cc: the oneflow job entry point plus the three
# network-generator tools under tools/.
# NOTE(review): presumably each entry is built as its own executable by
# oneflow_add_executable below -- confirm against the full CMake file.
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/oneflow.cpp)
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_resnet.cpp)
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_alexnet.cpp)
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/tools/gen_googlenet.cpp)
function(oneflow_add_executable)
if (BUILD_CUDA)
......
#include "oneflow/core/job/init_op_conf.h"
// Appends a new OperatorConf to the net of the global JobConf1 and fills in
// the fields shared by all single-input ops.
//
// The macro expands to several statements, so it must be used at function
// scope. It relies on two identifiers being visible at the expansion site:
//   name - the op name, in - the input blob name.
// It introduces two locals:
//   op      - the freshly added OperatorConf,
//   op_conf - the type-specific conf returned by op->mutable_op_conf_obj().
// The output blob name is always set to "out".
#define INIT_OP_CONF(op_conf_type, mutable_op_conf_obj) \
OperatorConf* op = Global<JobConf1>::Get()->mutable_net()->add_op(); \
op->set_name(name); \
op_conf_type* op_conf = op->mutable_op_conf_obj(); \
op_conf->set_in(in); \
op_conf->set_out("out");
namespace oneflow {
void InitPlacementGroup() {
PlacementGroup* placement_group =
Global<JobConf1>::Get()->mutable_placement()->add_placement_group();
CHECK_EQ(Global<JobConf1>::Get()->placement().placement_group_size(), 1);
ParallelConf* parallel_conf = new ParallelConf();
parallel_conf->set_policy(ParallelPolicy::kDataParallel);
parallel_conf->add_device_name("192.168.1.13:cpu:0-3");
placement_group->set_allocated_parallel_conf(parallel_conf);
}
// Registers the op called `name` in placement group 0 of the global JobConf1.
// The group must already exist (see InitPlacementGroup).
void AddOpToPlacementGroup(const std::string& name) {
  auto* placement = Global<JobConf1>::Get()->mutable_placement();
  placement->mutable_placement_group(0)->mutable_op_set()->add_op_name(name);
}
// TBD: support 1D, 2D, 3D with macro
std::string Conv2D(const std::string& name, const std::string& in, const int filters,
const std::string& padding, const std::string& data_format,
std::vector<int> kernel_size, const int strides, const int dilation_rate,
const bool use_bias, ActivationType activation) {
INIT_OP_CONF(Conv2DOpConf, mutable_conv_2d_conf)
op_conf->set_filters(filters);
op_conf->set_padding(padding);
op_conf->set_data_format(data_format);
for (std::vector<int>::iterator it = kernel_size.begin(); it != kernel_size.end(); ++it)
op_conf->add_kernel_size(*it);
// op_conf->add_kernel_size(kernel_size);
op_conf->add_strides(strides);
op_conf->add_strides(strides);
op_conf->add_dilation_rate(dilation_rate);
op_conf->add_dilation_rate(dilation_rate);
op_conf->set_use_bias(use_bias);
// op_conf->set_activation(activation);
// InitializerConf* weight_initializer = new InitializerConf();
// InitInitializerConf(weight_initializer, InitializerConf::kRandomNormalConf,
// InitInitializerConf(weight_initializer, InitializerConf::kMsraConf, 0.0, 1.0);
// op_conf->set_allocated_weight_initializer(weight_initializer);
/* InitializerConf* bias_initializer = new InitializerConf();
InitInitializerConf(bias_initializer, InitializerConf::kRandomUniformConf,
0.0, 1.0);
op_conf->set_allocated_bias_initializer(bias_initializer);*/
/*InitInitializerConf(weight_initializer, InitializerConf::kMsraConf,
0.0, 1.0); op_conf->set_allocated_weight_initializer(weight_initializer);
*/
if (use_bias) {
InitializerConf* bias_initializer = new InitializerConf();
InitInitializerConf(bias_initializer, InitializerConf::kConstantConf, 0.0, 1.0);
op_conf->set_allocated_bias_initializer(bias_initializer);
}
AddOpToPlacementGroup(name);
return name + "/" + "out";
}
// TBD: support average, max and 1D, 2D, 3D with macro
// Appends a 2-D max-pooling op named `name` that reads blob `in`.
// `pool_size` and `strides` are each written twice, once per spatial
// dimension. Returns the output blob name "<name>/out".
std::string MaxPooling2D(const std::string& name, const std::string& in, const int pool_size,
                         const int strides, const std::string& padding,
                         const std::string& data_format) {
  INIT_OP_CONF(MaxPooling2DOpConf, mutable_max_pooling_2d_conf)
  op_conf->set_padding(padding);
  op_conf->set_data_format(data_format);
  for (int dim = 0; dim < 2; ++dim) {
    op_conf->add_pool_size(pool_size);
    op_conf->add_strides(strides);
  }
  AddOpToPlacementGroup(name);
  return name + "/out";
}
// Appends a dropout op named `name` on blob `in` with the given `rate`.
// Returns the output blob name "<name>/out".
std::string Dropout(const std::string& name, const std::string& in, const double rate) {
  INIT_OP_CONF(DropoutOpConf, mutable_dropout_conf)
  op_conf->set_rate(rate);
  AddOpToPlacementGroup(name);
  return name + "/out";
}
// Appends a local-response-normalization op named `name` on blob `in`,
// forwarding depth_radius / bias / alpha / beta to the op conf unchanged.
// Returns the output blob name "<name>/out".
std::string LocalResponseNormalization(const std::string& name, const std::string& in,
                                       const int depth_radius, const double bias,
                                       const double alpha, const double beta) {
  INIT_OP_CONF(LocalResponseNormalizationOpConf, mutable_local_response_normalization_conf)
  op_conf->set_alpha(alpha);
  op_conf->set_beta(beta);
  op_conf->set_bias(bias);
  op_conf->set_depth_radius(depth_radius);
  AddOpToPlacementGroup(name);
  return name + "/out";
}
std::string FullyConnected(const std::string& name, const std::string& in, const int units,
bool use_bias) {
INIT_OP_CONF(FullyConnectedOpConf, mutable_fully_connected_conf)
op_conf->set_units(units);
op_conf->set_use_bias(use_bias);
InitializerConf* weight_initializer = new InitializerConf();
// InitInitializerConf(weight_initializer, InitializerConf::kRandomNormalConf,
InitInitializerConf(weight_initializer, InitializerConf::kMsraConf, 0.0, 1.0);
op_conf->set_allocated_weight_initializer(weight_initializer);
if (use_bias) {
InitializerConf* bias_initializer = new InitializerConf();
InitInitializerConf(bias_initializer, InitializerConf::kConstantConf, 0.0, 1.0);
op_conf->set_allocated_bias_initializer(bias_initializer);
}
AddOpToPlacementGroup(name);
return name + "/" + "out";
}
// Appends a 2-D average-pooling op named `name` that reads blob `in`.
// `pool_size` and `strides` are each written twice, once per spatial
// dimension. Returns the output blob name "<name>/out".
std::string AveragePooling2D(const std::string& name, const std::string& in, const int pool_size,
                             const int strides, const std::string& padding,
                             const std::string& data_format) {
  INIT_OP_CONF(AveragePooling2DOpConf, mutable_average_pooling_2d_conf)
  op_conf->set_padding(padding);
  op_conf->set_data_format(data_format);
  for (int dim = 0; dim < 2; ++dim) {
    op_conf->add_pool_size(pool_size);
    op_conf->add_strides(strides);
  }
  AddOpToPlacementGroup(name);
  return name + "/out";
}
// Appends a normalization op named `name` on blob `in`.
// Only `axis` and `activation` are written to the conf; momentum, epsilon,
// center, scale and the four *_init values are accepted but currently not
// forwarded. Returns the output blob name "<name>/out".
std::string BatchNorm(const std::string& name, const std::string& in, ActivationType activation,
                      int32_t axis, float momentum, float epsilon, bool center, bool scale,
                      float beta_init, float gamma_init, float mean_init, float variance_init) {
  INIT_OP_CONF(NormalizationOpConf, mutable_normalization_conf)
  op_conf->set_activation(activation);
  op_conf->set_axis(axis);
  AddOpToPlacementGroup(name);
  return name + "/out";
}
// Appends a ReLU op named `name` on blob `in`; the op has no extra settings.
// Returns the output blob name "<name>/out".
std::string Relu(const std::string& name, const std::string& in) {
  INIT_OP_CONF(ReluOpConf, mutable_relu_conf)
  AddOpToPlacementGroup(name);
  return name + "/out";
}
// Appends a softmax op named `name` on blob `in`, normalising along `axis`.
// Returns the output blob name "<name>/out".
std::string Softmax(const std::string& name, const std::string& in, const int axis) {
  INIT_OP_CONF(SoftmaxOpConf, mutable_softmax_conf)
  op_conf->set_axis(axis);
  AddOpToPlacementGroup(name);
  return name + "/out";
}
std::string Add(const std::string& name, const std::vector<std::string>& ins,
ActivationType activation) {
OperatorConf* op = Global<JobConf1>::Get()->mutable_net()->add_op();
op->set_name(name);
AddOpConf* op_conf = op->mutable_add_conf();
for (auto it = ins.begin(); it != ins.end(); ++it) { op_conf->add_in(*it); }
op_conf->set_out("out");
op_conf->set_activation(activation);
AddOpToPlacementGroup(name);
return name + "/" + "out";
}
std::string Concat(const std::string& name, const std::vector<std::string>& ins, const int axis) {
OperatorConf* op = Global<JobConf1>::Get()->mutable_net()->add_op();
op->set_name(name);
ConcatOpConf* op_conf = op->mutable_concat_conf();
for (auto it = ins.begin(); it != ins.end(); ++it) { op_conf->add_in(*it); }
op_conf->set_out("out");
op_conf->set_axis(axis);
AddOpToPlacementGroup(name);
return name + "/" + "out";
}
// Selects the initializer variant `type_case` inside `initializer` and fills
// it from the two generic parameters:
//   constant / constant-int : value = param1 (truncated to int for the latter)
//   random-uniform (-int)   : min = param1, max = param2
//   random-normal           : mean = param1, std = param2
//   xavier / msra           : variance_norm = VarianceNorm(int(param1))
// TYPE_NOT_SET leaves `initializer` untouched and only logs.
void InitInitializerConf(InitializerConf* initializer, const InitializerConf::TypeCase& type_case,
                         const float param1, const float param2) {
  switch (type_case) {
    case InitializerConf::kConstantConf:
      initializer->mutable_constant_conf()->set_value(param1);
      break;
    case InitializerConf::kConstantIntConf:
      initializer->mutable_constant_int_conf()->set_value(static_cast<int>(param1));
      break;
    case InitializerConf::kRandomUniformConf: {
      RandomUniformInitializerConf* uniform = initializer->mutable_random_uniform_conf();
      uniform->set_min(param1);
      uniform->set_max(param2);
      break;
    }
    case InitializerConf::kRandomUniformIntConf: {
      RandomUniformIntInitializerConf* uniform_int =
          initializer->mutable_random_uniform_int_conf();
      uniform_int->set_min(static_cast<int>(param1));
      uniform_int->set_max(static_cast<int>(param2));
      break;
    }
    case InitializerConf::kRandomNormalConf: {
      RandomNormalInitializerConf* normal = initializer->mutable_random_normal_conf();
      normal->set_mean(param1);
      normal->set_std(param2);
      break;
    }
    case InitializerConf::kXavierConf:
      initializer->mutable_xavier_conf()->set_variance_norm(
          static_cast<VarianceNorm>(static_cast<int>(param1)));
      break;
    case InitializerConf::kMsraConf:
      initializer->mutable_msra_conf()->set_variance_norm(
          static_cast<VarianceNorm>(static_cast<int>(param1)));
      break;
    case InitializerConf::TYPE_NOT_SET:
      LOG(INFO) << "InitializerConf::TYPE_NOT_SET";
      break;
  }
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_INIT_OP_CONF_H_
#define ONEFLOW_CORE_OPERATOR_INIT_OP_CONF_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/job_conf.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
namespace oneflow {
// Helpers for building a network description inside the global JobConf1.
//
// Every op builder below appends one op to the net, registers it in placement
// group 0 and returns the name of its output blob, "<name>/out", so calls can
// be chained by feeding one builder's result into the next one's `in`.

// Creates placement group 0; call once before any op builder.
void InitPlacementGroup();
// Adds the op called `name` to placement group 0.
void AddOpToPlacementGroup(const std::string& name);
// 2-D convolution. `strides` and `dilation_rate` are used for both spatial
// dimensions. `activation` is currently not written to the op conf.
std::string Conv2D(const std::string& name, const std::string& in, const int filters,
const std::string& padding = "same",
const std::string& data_format = "channels_last",
std::vector<int> kernel_size = {3, 3}, const int strides = 1,
const int dilation_rate = 1, const bool use_bias = true,
ActivationType activation = kNone);
// 2-D max pooling with a square window of `pool_size`.
std::string MaxPooling2D(const std::string& name, const std::string& in, const int pool_size = 3,
const int strides = 1, const std::string& padding = "same",
const std::string& data_format = "channels_last");
// Dropout with the given `rate`.
std::string Dropout(const std::string& name, const std::string& in, const double rate = 0.5);
// Local response normalization.
std::string LocalResponseNormalization(const std::string& name, const std::string& in,
const int depth_radius = 5, const double bias = 1.0,
const double alpha = 1.0, const double beta = 0.5);
// Fully-connected layer with `units` outputs.
std::string FullyConnected(const std::string& name, const std::string& in, const int units = 1000,
bool use_bias = true);
// 2-D average pooling with a square window of `pool_size`.
std::string AveragePooling2D(const std::string& name, const std::string& in,
const int pool_size = 3, const int strides = 1,
const std::string& padding = "same",
const std::string& data_format = "channels_last");
// Normalization op. Only `activation` and `axis` are currently written to the
// op conf; the remaining parameters are accepted but ignored.
std::string BatchNorm(const std::string& name, const std::string& in,
ActivationType activation = kNone, int32_t axis = 1, float momentum = 0.99,
float epsilon = 0.001, bool center = true, bool scale = true,
float beta_init = 0.0, float gamma_init = 1.0, float mean_init = 0.0,
float variance_init = 1.0);
// ReLU activation.
std::string Relu(const std::string& name, const std::string& in);
// Softmax along `axis`.
std::string Softmax(const std::string& name, const std::string& in, const int axis = -1);
// Element-wise sum of all blobs in `ins`, followed by `activation`.
std::string Add(const std::string& name, const std::vector<std::string>& ins,
ActivationType activation = kNone);
// Concatenation of all blobs in `ins` along `axis`.
std::string Concat(const std::string& name, const std::vector<std::string>& ins,
const int axis = 0);
// Fills `initializer` with the variant selected by `type_case`. The meaning of
// param1 / param2 depends on the variant (value, min/max, mean/std, or the
// variance-norm enum value passed as a float).
void InitInitializerConf(InitializerConf* initializer, const InitializerConf::TypeCase& type_case,
const float param1, const float param2 = 0.0);
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_INIT_OP_CONF_H_
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "oneflow/core/job/init_op_conf.h"
namespace oneflow {
// One AlexNet feature stage: a square-kernel convolution named `name`,
// followed by "relu_<name>", then optionally "lrn_<name>" and "pool_<name>"
// (3x3 max pooling, stride 2). Returns the output blob of the last op added.
std::string AlexNetConv2DBlock(const std::string& name, const std::string in, const int filters,
                               const std::string& padding = "same",
                               const std::string data_format = "channels_last",
                               const int kernel_size = 3, const int strides = 1,
                               const int dilation_rate = 1, const bool use_bias = true,
                               bool with_lrn = false, bool with_pooling = false) {
  std::string blob = Conv2D(name, in, filters, padding, data_format, {kernel_size, kernel_size},
                            strides, dilation_rate, use_bias);
  blob = Relu("relu_" + name, blob);
  if (with_lrn) { blob = LocalResponseNormalization("lrn_" + name, blob, 2, 2.0, 1e-4, 0.75); }
  if (with_pooling) { blob = MaxPooling2D("pool_" + name, blob, 3, 2); }
  return blob;
}
// Builds the five convolution stages conv1..conv5 of AlexNet on top of blob
// `in` and returns the output blob of the last stage. All stages use "same"
// padding, "channels_last", dilation 1 and a bias term.
std::string AlexNetFeature(const std::string& in) {
  struct StageSpec {
    const char* name;
    int filters;
    int kernel_size;
    int strides;
    bool with_lrn;
    bool with_pooling;
  };
  const StageSpec stages[] = {
      {"conv1", 64, 11, 4, true, true},    {"conv2", 192, 5, 1, true, true},
      {"conv3", 384, 3, 1, false, false},  {"conv4", 256, 3, 1, false, false},
      {"conv5", 256, 3, 1, false, true},
  };
  std::string blob = in;
  for (const StageSpec& stage : stages) {
    blob = AlexNetConv2DBlock(stage.name, blob, stage.filters, "same", "channels_last",
                              stage.kernel_size, stage.strides, 1 /*dilation_rate*/,
                              true /*use_bias*/, stage.with_lrn, stage.with_pooling);
  }
  return blob;
}
// Builds the three-layer classifier head on top of blob `in`:
//   [drop_N] fcN relu_fcN [drop_N+1] fcN+1 relu_fcN+1 fcN+2
// where N = first_fc_layer_num, both hidden layers have 4096 units, the last
// layer has `num_classes` units and the dropout ops (rate 0.5) are only added
// when `with_dropout` is true. Returns the output blob of the last layer.
std::string Classifier(const std::string& in, const int num_classes, bool with_dropout = false,
                       int first_fc_layer_num = 6) {
  std::string blob = in;
  // The two hidden layers share the same dropout -> fc -> relu pattern.
  for (int offset = 0; offset < 2; ++offset) {
    const std::string layer_id = std::to_string(first_fc_layer_num + offset);
    if (with_dropout) { blob = Dropout("drop_" + layer_id, blob, 0.5); }
    const std::string fc_name = "fc" + layer_id;
    blob = FullyConnected(fc_name, blob, 4096);
    blob = Relu("relu_" + fc_name, blob);
  }
  return FullyConnected("fc" + std::to_string(first_fc_layer_num + 2), blob, num_classes);
}
// Builds the AlexNet op graph (features, classifier with dropout, softmax
// "prob") in a fresh global JobConf1 and writes the net and its placement to
// ./alexnet.prototxt and ./alexnet_placement.prototxt.
// `groups` is only logged; it does not influence the generated network.
void GenAlexNet(const int groups) {
  Global<JobConf1>::New();
  InitPlacementGroup();
  LOG(INFO) << "Create AlexNet. groups = " << groups;
  // Blobs are referenced as "<op name>/<blob name>" everywhere else in these
  // generators (every builder returns "<name>/out"), so the network input must
  // be given in the same form rather than as a bare op name.
  std::string op_out = AlexNetFeature("feature/out");
  op_out = Classifier(op_out, 1000, true);
  op_out = Softmax("prob", op_out);
  PrintProtoToTextFile(Global<JobConf1>::Get()->net(), "./alexnet.prototxt");
  PrintProtoToTextFile(Global<JobConf1>::Get()->placement(), "./alexnet_placement.prototxt");
  Global<JobConf1>::Delete();
}
} // namespace oneflow
// Command-line flag forwarded to GenAlexNet.
DEFINE_int32(groups, 1, "groups number 1 or 2");

// Entry point of the AlexNet generator: set up logging, parse flags, generate.
int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  const int groups = FLAGS_groups;
  oneflow::GenAlexNet(groups);
  return 0;
}
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "oneflow/core/job/init_op_conf.h"
#define DATA_FORMAT "channels_first"
#define CONCAT_AXIS 1
namespace oneflow {
// Bias-free convolution "<name>_conv" (dilation 1, DATA_FORMAT layout)
// followed by "<name>_bn", a normalization op with fused ReLU activation.
// Returns the output blob of the normalization op.
std::string ConvBnReluBlock(const std::string& name, const std::string& in, int filters,
                            const std::vector<int> kernel_size, int strides,
                            const std::string& padding) {
  const std::string conv_out = Conv2D(name + "_conv", in, filters, padding, DATA_FORMAT,
                                      kernel_size, strides, 1, false, kNone);
  return BatchNorm(name + "_bn", conv_out, kRelu);
}
// Inception block with four parallel branches on `in`, concatenated along
// CONCAT_AXIS: 1x1 conv, 1x1 -> 5x5 conv, 1x1 -> 3x3 -> 3x3 conv, and
// 3x3 average pooling -> 1x1 conv.
// num5x5 must hold 2 filter counts and num3x3d must hold 3.
std::string InceptionBlock1(const std::string& name, const std::string& in, const int num1x1,
                            const std::vector<int> num5x5, const std::vector<int> num3x3d,
                            const int numPool) {
  // Branch 1: 1x1 convolution.
  const std::string b1x1 = ConvBnReluBlock(name + "_b1x1", in, num1x1, {1, 1}, 1, "same");

  // Branch 2: 1x1 reduction followed by 5x5 convolution.
  std::string tmp = ConvBnReluBlock(name + "_b5x5-1", in, num5x5[0], {1, 1}, 1, "same");
  const std::string b5x5 = ConvBnReluBlock(name + "_b5x5", tmp, num5x5[1], {5, 5}, 1, "same");

  // Branch 3: 1x1 reduction followed by two 3x3 convolutions.
  tmp = ConvBnReluBlock(name + "_b3x3d-1", in, num3x3d[0], {1, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b3x3d-2", tmp, num3x3d[1], {3, 3}, 1, "same");
  const std::string b3x3d = ConvBnReluBlock(name + "_b3x3d", tmp, num3x3d[2], {3, 3}, 1, "same");

  // Branch 4: average pooling followed by 1x1 convolution.
  tmp = AveragePooling2D(name + "_pool_bPool", in, 3, 1, "same", DATA_FORMAT);
  const std::string b_pool = ConvBnReluBlock(name + "_bPool", tmp, numPool, {1, 1}, 1, "same");

  return Concat(name + "_concat", {b1x1, b5x5, b3x3d, b_pool}, CONCAT_AXIS);
}
// Reduction block with three parallel branches on `in`, concatenated along
// CONCAT_AXIS: strided 3x3 conv, 1x1 -> 3x3 -> strided 3x3 conv, and strided
// 3x3 max pooling. All strided ops use stride 2 with "valid" padding.
// num3x3d must hold 3 filter counts.
std::string InceptionBlock2(const std::string& name, const std::string& in, const int num3x3,
                            const std::vector<int> num3x3d) {
  // Branch 1: strided 3x3 convolution.
  const std::string b3x3 = ConvBnReluBlock(name + "_b3x3", in, num3x3, {3, 3}, 2, "valid");

  // Branch 2: 1x1 reduction, 3x3 convolution, strided 3x3 convolution.
  std::string tmp = ConvBnReluBlock(name + "_b3x3d-1", in, num3x3d[0], {1, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b3x3d-2", tmp, num3x3d[1], {3, 3}, 1, "same");
  const std::string b3x3d = ConvBnReluBlock(name + "_b3x3d", tmp, num3x3d[2], {3, 3}, 2, "valid");

  // Branch 3: strided max pooling.
  const std::string pooled = MaxPooling2D(name + "_pool", in, 3, 2, "valid", DATA_FORMAT);

  return Concat(name + "_concat", {b3x3, b3x3d, pooled}, CONCAT_AXIS);
}
// Inception block with factorised 7x7 convolutions. Four parallel branches on
// `in`, concatenated along CONCAT_AXIS: 1x1 conv, 1x1 -> 1x7 -> 7x1 conv,
// 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7 conv, and 3x3 average pooling -> 1x1 conv.
// num7x7 must hold 3 filter counts and num7x7d must hold 5.
std::string InceptionBlock3(const std::string& name, const std::string& in, const int num1x1,
                            const std::vector<int> num7x7, const std::vector<int> num7x7d,
                            const int numPool) {
  // Branch 1: 1x1 convolution.
  const std::string b1x1 = ConvBnReluBlock(name + "_b1x1", in, num1x1, {1, 1}, 1, "same");

  // Branch 2: factorised 7x7 convolution.
  std::string tmp = ConvBnReluBlock(name + "_b7x7-1", in, num7x7[0], {1, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7-2", tmp, num7x7[1], {1, 7}, 1, "same");
  const std::string b7x7 = ConvBnReluBlock(name + "_b7x7", tmp, num7x7[2], {7, 1}, 1, "same");

  // Branch 3: two stacked factorised 7x7 convolutions.
  tmp = ConvBnReluBlock(name + "_b7x7d-1", in, num7x7d[0], {1, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7d-2", tmp, num7x7d[1], {7, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7d-3", tmp, num7x7d[2], {1, 7}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7d-4", tmp, num7x7d[3], {7, 1}, 1, "same");
  const std::string b7x7d = ConvBnReluBlock(name + "_b7x7d", tmp, num7x7d[4], {1, 7}, 1, "same");

  // Branch 4: average pooling followed by 1x1 convolution.
  tmp = AveragePooling2D(name + "_pool_bPool", in, 3, 1, "same", DATA_FORMAT);
  const std::string b_pool = ConvBnReluBlock(name + "_bPool", tmp, numPool, {1, 1}, 1, "same");

  return Concat(name + "_concat", {b1x1, b7x7, b7x7d, b_pool}, CONCAT_AXIS);
}
// Reduction block with three parallel branches on `in`, concatenated along
// CONCAT_AXIS: 1x1 -> strided 3x3 conv, 1x1 -> 1x7 -> 7x1 -> strided 3x3 conv,
// and strided 3x3 max pooling. Strided ops use stride 2 with "valid" padding.
// num3x3 must hold 2 filter counts and num7x7_3x3 must hold 4.
std::string InceptionBlock4(const std::string& name, const std::string& in,
                            const std::vector<int> num3x3, const std::vector<int> num7x7_3x3) {
  // Branch 1: 1x1 reduction followed by strided 3x3 convolution.
  std::string tmp = ConvBnReluBlock(name + "_b3x3-1", in, num3x3[0], {1, 1}, 1, "same");
  const std::string b3x3 = ConvBnReluBlock(name + "_b3x3", tmp, num3x3[1], {3, 3}, 2, "valid");

  // Branch 2: 1x1 reduction, factorised 7x7, strided 3x3 convolution.
  tmp = ConvBnReluBlock(name + "_b7x7-3x3-1", in, num7x7_3x3[0], {1, 1}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7-3x3-2", tmp, num7x7_3x3[1], {1, 7}, 1, "same");
  tmp = ConvBnReluBlock(name + "_b7x7-3x3-3", tmp, num7x7_3x3[2], {7, 1}, 1, "same");
  const std::string b7x7_3x3 =
      ConvBnReluBlock(name + "_b7x7-3x3", tmp, num7x7_3x3[3], {3, 3}, 2, "valid");

  // Branch 3: strided max pooling.
  const std::string pooled = MaxPooling2D(name + "_pool", in, 3, 2, "valid", DATA_FORMAT);

  return Concat(name + "_concat", {b3x3, b7x7_3x3, pooled}, CONCAT_AXIS);
}
// Inception block with split 3x3 convolutions. Four parallel branches on `in`,
// concatenated along CONCAT_AXIS:
//   1. 1x1 conv
//   2. 1x1 conv, then 1x3 and 3x1 convs in parallel, concatenated
//   3. 1x1 -> 3x3 conv, then 1x3 and 3x1 convs in parallel, concatenated
//   4. 3x3 pooling (average when pool_type == "avg", max for any other value)
//      followed by 1x1 conv
// num3x3 must hold 3 filter counts and num3x3_3x3 must hold 4.
std::string InceptionBlock5(const std::string& name, const std::string& in, const int num1x1,
                            const std::vector<int> num3x3, const std::vector<int> num3x3_3x3,
                            const int numPool, const std::string& pool_type) {
  // Branch 1.
  const std::string b1x1 = ConvBnReluBlock(name + "_b1x1", in, num1x1, {1, 1}, 1, "same");

  // Branch 2: both halves read the same 1x1 reduction.
  std::string stem = ConvBnReluBlock(name + "_b3x3-1", in, num3x3[0], {1, 1}, 1, "same");
  const std::string b3x3_a = ConvBnReluBlock(name + "_b3x3-2", stem, num3x3[1], {1, 3}, 1, "same");
  const std::string b3x3_b = ConvBnReluBlock(name + "_b3x3-3", stem, num3x3[2], {3, 1}, 1, "same");
  const std::string b3x3 = Concat(name + "_concat3x3", {b3x3_a, b3x3_b}, CONCAT_AXIS);

  // Branch 3: both halves read the same 1x1 -> 3x3 stem.
  stem = ConvBnReluBlock(name + "_b3x3-3x3-1", in, num3x3_3x3[0], {1, 1}, 1, "same");
  stem = ConvBnReluBlock(name + "_b3x3-3x3-2", stem, num3x3_3x3[1], {3, 3}, 1, "same");
  const std::string b3x3_3x3_a =
      ConvBnReluBlock(name + "_b3x3_3x3_3", stem, num3x3_3x3[2], {1, 3}, 1, "same");
  const std::string b3x3_3x3_b =
      ConvBnReluBlock(name + "_b3x3_3x3_4", stem, num3x3_3x3[3], {3, 1}, 1, "same");
  const std::string b3x3_3x3 =
      Concat(name + "_concat3x3-3x3", {b3x3_3x3_a, b3x3_3x3_b}, CONCAT_AXIS);

  // Branch 4.
  const std::string pooled =
      (pool_type == "avg") ? AveragePooling2D(name + "_pool", in, 3, 1, "same", DATA_FORMAT)
                           : MaxPooling2D(name + "_pool", in, 3, 1, "same", DATA_FORMAT);
  const std::string b_pool = ConvBnReluBlock(name + "_bPool", pooled, numPool, {1, 1}, 1, "same");

  return Concat(name + "_concat", {b1x1, b3x3, b3x3_3x3, b_pool}, CONCAT_AXIS);
}
// Appends the whole Inception v3 graph to the global net, reading the input
// from blob "feature/out": stem convolutions, eleven "mixed" blocks, the main
// prediction head ending in "fc1000", and an auxiliary head ending in "auxFc"
// that branches off the output of "mixed8". Nothing is returned.
// The trailing size comments are the original author's "H x W x C" notes.
void InceptionV3Model() {
  // Stem.
  std::string net = ConvBnReluBlock("conv1", "feature/out", 32, {3, 3}, 2, "valid");  // 149x149x32
  net = ConvBnReluBlock("conv2", net, 32, {3, 3}, 1, "valid");                        // 147x147x32
  net = ConvBnReluBlock("conv3", net, 64, {3, 3}, 1, "same");                         // 147x147x64
  net = MaxPooling2D("pool1", net, 3, 2, "valid", DATA_FORMAT);                       // 73x73x64
  net = ConvBnReluBlock("conv4", net, 80, {1, 1}, 1, "valid");                        // 73x73x80
  net = ConvBnReluBlock("conv5", net, 192, {3, 3}, 1, "valid");                       // 71x71x192
  net = MaxPooling2D("pool2", net, 3, 2, "valid", DATA_FORMAT);                       // 35x35x192

  // Inception blocks.
  net = InceptionBlock1("mixed1", net, 64, {48, 64}, {64, 96, 96}, 32);  // 35x35x256
  net = InceptionBlock1("mixed2", net, 64, {48, 64}, {64, 96, 96}, 64);  // 35x35x288
  net = InceptionBlock1("mixed3", net, 64, {48, 64}, {64, 96, 96}, 64);  // 35x35x288
  net = InceptionBlock2("mixed4", net, 384, {64, 96, 96});               // 17x17x768
  net = InceptionBlock3("mixed5", net, 192, {128, 128, 192}, {128, 128, 128, 128, 192}, 192);
  net = InceptionBlock3("mixed6", net, 192, {160, 160, 192}, {160, 160, 160, 160, 192}, 192);
  net = InceptionBlock3("mixed7", net, 192, {128, 128, 192}, {160, 160, 160, 160, 192}, 192);
  // 17x17x768; kept separately because the auxiliary head also reads it.
  const std::string mixed8 =
      InceptionBlock3("mixed8", net, 192, {128, 128, 192}, {192, 192, 192, 192, 192}, 192);
  net = InceptionBlock4("mixed9", mixed8, {192, 320}, {192, 192, 192, 192});  // 8x8x1280
  net = InceptionBlock5("mixed10", net, 320, {384, 384, 384}, {448, 384, 384, 384}, 192,
                        "avg");  // 8x8x2048
  net = InceptionBlock5("mixed11", net, 320, {384, 384, 384}, {448, 384, 384, 384}, 192,
                        "max");  // 8x8x2048

  // Prediction head.
  net = AveragePooling2D("pool3", net, 8, 1, "valid", DATA_FORMAT);  // 1x1x2048
  net = Dropout("drop", net, 0.5);
  FullyConnected("fc1000", net, 1000);

  // Auxiliary head, fed from mixed8 (17x17x768).
  std::string aux = AveragePooling2D("auxPool", mixed8, 5, 3, "valid", DATA_FORMAT);  // 5x5x768
  aux = ConvBnReluBlock("auxConv1", aux, 128, {1, 1}, 1, "same");                     // 5x5x128
  aux = ConvBnReluBlock("auxConv2", aux, 768, {5, 5}, 1, "valid");                    // 1x1x768
  FullyConnected("auxFc", aux, 1024);
}
// Builds the Inception v3 graph in a fresh global JobConf1 and writes the net
// and its placement to ./googlenet.prototxt and ./googlenet_placement.prototxt.
void GenGoogLeNet() {
  Global<JobConf1>::New();
  InitPlacementGroup();
  LOG(INFO) << "Create GoogLeNet/Inception V3.";
  InceptionV3Model();

  const auto* job_conf = Global<JobConf1>::Get();
  PrintProtoToTextFile(job_conf->net(), "./googlenet.prototxt");
  PrintProtoToTextFile(job_conf->placement(), "./googlenet_placement.prototxt");
  Global<JobConf1>::Delete();
}
} // namespace oneflow
// DEFINE_int32(groups, 1, "groups number 1 or 2");
// Entry point of the GoogLeNet / Inception v3 generator: initialise glog,
// parse (and strip) command-line flags, then generate the prototxt files.
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true);
oneflow::GenGoogLeNet();
return 0;
}
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "oneflow/core/job/init_op_conf.h"
#define DATA_FORMAT "channels_first"
namespace oneflow {
// conv + batch_norm + relu
// Square-kernel convolution named `name` followed by the normalization op
// "bn_<name>" (axis 1, momentum 0.997, epsilon 0.0000101). The normalization
// op carries a fused ReLU when `use_relu` is true and no activation otherwise.
// Returns the output blob of the normalization op.
std::string Conv2DBlock(bool use_relu, const std::string& name, const std::string& in,
                        const int filters, const std::string& padding = "same",
                        const std::string& data_format = DATA_FORMAT, const int kernel_size = 3,
                        const int strides = 1, const int dilation_rate = 1,
                        const bool use_bias = false) {
  const std::string conv_out = Conv2D(name, in, filters, padding, data_format,
                                      {kernel_size, kernel_size}, strides, dilation_rate, use_bias);
  // The two cases differ only in the activation fused into the normalization.
  return BatchNorm("bn_" + name, conv_out, use_relu ? kRelu : kNone, 1, 0.997, 0.0000101);
}
// one resnet contains 4 residual blocks, id from 2 to 5
// res id: residual block id, from 2 to 5.
// residual block name: res(res id)
// one residual block contains some building blocks
// building block id in residual block.
// building_block_name: res*_bb(building block id)
// e.g. res2a(caffe) -> res2_bb1(of), res3c -> res3_bb3
// one building block contains 3 conv blocks(>50), 2 conv blocks(18, 34)
// conv_block_name: res*_bb*_b1a/b/c(main branch), res*_bb*_b2(shortcut branch)
// Appends one bottleneck building block "<res_block_name>_bb<id>" on blob `in`.
//
// Main branch (b1a/b1b/b1c): 1x1 conv -> 3x3 conv -> 1x1 conv with filter1_2,
// filter1_2 and filter3 filters; the last one has no ReLU.
// Shortcut branch (b2): only the first building block (id 1) gets a 1x1
// projection conv with filter3 filters; later blocks reuse `in` directly.
// The first block uses stride 2 on b1a and b2 when `down_sampling` is true.
// Both branches are summed by "<block>_add" with a fused ReLU, whose output
// blob is returned.
std::string BuildingBlock(const std::string& res_block_name, int building_block_id,
                          const std::string& in, int filter1_2, int filter3,
                          bool down_sampling = true) {
  const std::string bb_name = res_block_name + "_bb" + std::to_string(building_block_id);
  const bool is_first = (building_block_id == 1);
  const int stride = (is_first && down_sampling) ? 2 : 1;

  // Shortcut branch - b2.
  std::string shortcut = in;
  if (is_first) {
    shortcut = Conv2DBlock(false /*no relu*/, bb_name + "_b2", in, filter3, "same", DATA_FORMAT, 1,
                           stride);
  }

  // Main branch - b1.
  std::string main_branch =
      Conv2DBlock(true, bb_name + "_b1a", in, filter1_2, "same", DATA_FORMAT, 1, stride);
  main_branch =
      Conv2DBlock(true, bb_name + "_b1b", main_branch, filter1_2, "same", DATA_FORMAT, 3, 1);
  main_branch = Conv2DBlock(false /*no relu*/, bb_name + "_b1c", main_branch, filter3, "same",
                            DATA_FORMAT, 1, 1);

  // Element-wise sum of both branches.
  return Add(bb_name + "_add", {main_branch, shortcut}, kRelu);
}
// Appends residual block "res<res_block_id>", a chain of `building_block_num`
// building blocks starting from blob `in`. Only the first building block is
// asked to down-sample, and only when `down_sampling` is true and
// res_block_id > 2. Returns the output blob of the last building block
// (or `in` when building_block_num <= 0).
std::string ResidualBlock(int res_block_id, int building_block_num, const std::string& in,
                          int filter1_2, int filter3, bool down_sampling = true) {
  const std::string res_block_name = "res" + std::to_string(res_block_id);
  std::string blob = in;
  for (int bb_id = 1; bb_id <= building_block_num; ++bb_id) {
    const bool reduce = (bb_id == 1) && down_sampling && (res_block_id > 2);
    blob = BuildingBlock(res_block_name, bb_id, blob, filter1_2, filter3, reduce);
  }
  return blob;
}
std::string ResidualBlocks(const int layer_num, const std::string& in) {
std::string op_out = in;
// layer number -> building block number array(residual block 2, 3, 4, 5)
HashMap<int, std::vector<int>> layer_num2bb_num = {
{50, {3, 4, 6, 3}}, {101, {3, 4, 23, 3}}, {152, {3, 8, 36, 3}},
};
// filter num of each residual block
int res_block_filter_num[4][2] = {
{64, 256}, {128, 512}, {256, 1024}, {512, 2048},
};
std::vector<int> bb_num = layer_num2bb_num[layer_num];
for (int i = 0; i < 4; i++) {
op_out = ResidualBlock(i + 2, bb_num[i], op_out, res_block_filter_num[i][0],
res_block_filter_num[i][1], i > 0);
}
return op_out;
}
// Replaces, in place, every non-overlapping occurrence of `find` in `source`
// with `replace`, scanning left to right. Text inserted by a replacement is
// never rescanned, so `replace` may safely contain `find`.
// An empty `find` pattern leaves `source` unchanged: it would otherwise match
// at every position and the loop would never terminate.
void FindAndReplace(std::string& source, std::string const& find, std::string const& replace) {
  if (find.empty()) { return; }
  std::string::size_type pos = 0;
  while ((pos = source.find(find, pos)) != std::string::npos) {
    source.replace(pos, find.length(), replace);
    pos += replace.length();  // continue after the inserted text
  }
}
void DLNet2csv(const DLNetConf net) {
PersistentOutStream out_stream(LocalFS(), "./resnet.csv");
out_stream << "name,type,inputs,outputs,params\n";
std::string str;
for (const OperatorConf& cur_op_conf : net.op()) {
// std::string str;
// google::protobuf::TextFormat::PrintToString(cur_op_conf, &str);
// LOG(INFO) << str;
out_stream << cur_op_conf.name() << ",";
// out_stream << std::to_string(cur_op_conf.op_type_case()) << ",";
// out_stream << cur_op_conf.in() << ",";
// out_stream << cur_op_conf.out() << ",";
google::protobuf::TextFormat::PrintToString(cur_op_conf, &str);
FindAndReplace(str, "{", "");
FindAndReplace(str, "}", "");
FindAndReplace(str, "\n", ",");
out_stream << str;
out_stream << "\n";
}
}
// Placeholder for exporting `net` as a graph; it currently does nothing
// because the LogicalGraph calls below are commented out.
void DLNet2Dot(const DLNetConf net) {
// LogicalGraph::NewSingleton(net);
// LogicalGraph::DeleteSingleton();
}
// Builds a ResNet with `layer_num` layers in a fresh global JobConf1, reading
// the input from blob "transpose/out": conv1 (7x7, stride 2) + bn/relu, pool1,
// the residual blocks, pool5 (7x7 average) and the 1000-way "fc1000".
// The net and its placement are written to ./resnet.prototxt and
// ./resnet_placement.prototxt. No softmax op is appended.
void GenResNet(const int layer_num) {
  Global<JobConf1>::New();
  InitPlacementGroup();

  std::string blob = Conv2DBlock(true, "conv1", "transpose/out", 64, "same", DATA_FORMAT, 7, 2, 1);
  blob = MaxPooling2D("pool1", blob, 3, 2, "same", DATA_FORMAT);
  blob = ResidualBlocks(layer_num, blob);
  blob = AveragePooling2D("pool5", blob, 7, 1, "valid", DATA_FORMAT);
  blob = FullyConnected("fc1000", blob, 1000);

  const auto* job_conf = Global<JobConf1>::Get();
  PrintProtoToTextFile(job_conf->net(), "./resnet.prototxt");
  PrintProtoToTextFile(job_conf->placement(), "./resnet_placement.prototxt");
  Global<JobConf1>::Delete();
}
} // namespace oneflow
// Command-line flag selecting the ResNet depth passed to GenResNet.
DEFINE_int32(layer_num, 50, "ResNet layer number:50, 101, 152");

// Entry point of the ResNet generator: set up logging, parse flags, generate.
int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  const int layer_num = FLAGS_layer_num;
  oneflow::GenResNet(layer_num);
  return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册