Commit 43ee7582 authored by Shiyuan Shang-Guan, committed by Jinhui Yuan

Fix bug in model parallel (#1345)

* fix conv in model parallel

* add TODO


Former-commit-id: 5ed0f04822a94ab9941fda91c8ba8fb18c36aeeb
Parent b3fb9acf
@@ -65,7 +65,7 @@ void ConvKernel<DeviceType::kGPU, T>::KernelInitWithCudnn(const ParallelContext*
       new CudnnConvDesc(GetDataType<T>::value, in_shape, this->GetCustomizedOpConf()));
   if (this->template GetValFromCustomizedOpConf<bool>("use_bias")) {
-    int32_t filters = this->template GetValFromCustomizedOpConf<int32_t>("filters");
+    int32_t filters = Shape(this->GetConvKernelConf().bias()).At(0);
     if ((this->OpKernelDim() == 1) || (this->OpKernelDim() == 2)) {
       if (data_format == "channels_first") {
         this->bias_desc_.reset(
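Note on the kernel change: under model parallelism the conv op's weight and bias are split across devices along the output-channel axis, so the per-device bias blob holds only a slice of the logical "filters" count from the op conf. Sizing the cuDNN bias descriptor from the op-conf value therefore mismatched the blob actually present on each device; reading dim 0 of the bias shape recorded in ConvKernelConf uses the true per-device count. A minimal numeric sketch (the 256/4 figures are illustrative, not from this commit):

// Illustrative only: a conv with 256 logical output channels,
// model-parallel over 4 devices, leaves a bias of length 64 per device.
#include <cassert>
#include <cstdint>

int main() {
  const int32_t logical_filters = 256;  // "filters" attribute in the op conf
  const int32_t parallel_num = 4;       // model-parallel device count
  const int32_t bias_dim0 = logical_filters / parallel_num;  // per-device bias length

  // The old code sized the cuDNN bias descriptor with logical_filters;
  // the fix sizes it with bias_dim0, matching the blob on each device.
  assert(bias_dim0 != logical_filters);
  return 0;
}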
@@ -10,14 +10,15 @@ message ConvKernelConf {
   required ShapeProto in = 1;
   required ShapeProto out = 2;
   required ShapeProto weight = 3;
-  required int32 dim = 4;
-  repeated int32 pad_small_side = 5;
-  repeated int32 pad_large_side = 6;
-  repeated int32 dilation_rate = 7;
-  repeated int32 strides = 8;
-  optional int32 cudnn_fwd_algo = 9 [default = -1];
-  optional int32 cudnn_bwd_filter_algo = 10 [default = -1];
-  optional int32 cudnn_bwd_data_algo = 11 [default = -1];
+  optional ShapeProto bias = 4;
+  required int32 dim = 5;
+  repeated int32 pad_small_side = 6;
+  repeated int32 pad_large_side = 7;
+  repeated int32 dilation_rate = 8;
+  repeated int32 strides = 9;
+  optional int32 cudnn_fwd_algo = 10 [default = -1];
+  optional int32 cudnn_bwd_filter_algo = 11 [default = -1];
+  optional int32 cudnn_bwd_data_algo = 12 [default = -1];
 }
 message DropoutKernelConf {
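Two details of this proto change are worth noting. First, the new bias field is optional and is only filled when use_bias is set (see the operator change below), so readers must guard access to it. Second, inserting it as field 4 renumbers every subsequent field, which breaks binary compatibility with any previously serialized ConvKernelConf; this is presumably acceptable because the kernel conf is regenerated during each compilation rather than persisted (an inference from how GenKernelConfWithCudnn produces it below, not something the commit states).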
@@ -244,6 +244,9 @@ void ConvOp<NDims>::GenKernelConfWithCudnn(
   GetBlobDesc4BnInOp("in")->shape().ToProto(conv_conf->mutable_in());
   GetBlobDesc4BnInOp("out")->shape().ToProto(conv_conf->mutable_out());
   GetBlobDesc4BnInOp("weight")->shape().ToProto(conv_conf->mutable_weight());
+  if (GetValFromCustomizedConf<bool>("use_bias")) {
+    GetBlobDesc4BnInOp("bias")->shape().ToProto(conv_conf->mutable_bias());
+  }
   std::vector<int32_t> pad_small_side;
   std::vector<int32_t> pad_large_side;
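For context, GenKernelConfWithCudnn consumes the blob descs produced by shape inference for each device, so the bias shape recorded here is already the split, per-device shape that the kernel change above relies on. A simplified stand-in sketch of the ShapeProto-to-Shape round trip (the two structs below are stand-ins, not OneFlow's real types):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-ins for OneFlow's ShapeProto / Shape.
struct ShapeProto { std::vector<int64_t> dim; };
struct Shape {
  explicit Shape(const ShapeProto& p) : dim_(p.dim) {}
  int64_t At(std::size_t i) const { return dim_.at(i); }
  std::vector<int64_t> dim_;
};

int main() {
  // The op compiler records the per-device bias shape in the kernel conf;
  // e.g. 256 logical filters split over 4 devices gives (64,).
  ShapeProto bias_in_conf{{64}};
  std::cout << "per-device filters = " << Shape(bias_in_conf).At(0) << "\n";
  return 0;
}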
@@ -105,8 +105,10 @@ void Operator::InferBwBufBlobDescs(std::function<BlobDesc*(const std::string&)>
 }

 void Operator::FixParallelDesc(ParallelDesc* pr_desc) const {
+  // TODO(shiyuan): When all ops are model-parallel, some ops without a model will be forced to
+  // data-parallel, causing the data to be split.
   if (model_bns().empty() && const_model_bns().empty()) {
-    pr_desc->set_policy(ParallelPolicy::kDataParallel);  // TODO(shiyuan)
+    pr_desc->set_policy(ParallelPolicy::kDataParallel);
   }
   if (pr_desc->policy() == kModelParallel && MaxModelSplitNum() != -1) {
     pr_desc->RemoveNeedlessDevice(op_name(), MaxModelSplitNum());
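The FixParallelDesc change only expands the existing TODO into a full explanation: an op that owns no model blobs (empty model_bns and const_model_bns) has nothing to split along the model axis, so its placement is forced to data parallel, and the TODO records that in an otherwise model-parallel net this silently switches such ops to splitting their data instead. A simplified stand-in for that decision (not OneFlow's real API; the model and const-model blob-name lists are merged into one here):

#include <iostream>
#include <string>
#include <vector>

enum class ParallelPolicy { kDataParallel, kModelParallel };

// Stand-in for the first rule in Operator::FixParallelDesc.
ParallelPolicy FixPolicy(ParallelPolicy requested,
                         const std::vector<std::string>& model_bns) {
  // Ops with no model blobs cannot be model-parallel, so they fall
  // back to data parallel regardless of the requested policy.
  if (model_bns.empty()) { return ParallelPolicy::kDataParallel; }
  return requested;
}

int main() {
  // An op like relu or pooling has no model blobs, so a requested
  // model-parallel placement is forced back to data parallel.
  ParallelPolicy p = FixPolicy(ParallelPolicy::kModelParallel, {});
  std::cout << (p == ParallelPolicy::kDataParallel ? "forced to data parallel"
                                                   : "kept model parallel")
            << "\n";
  return 0;
}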