提交 46da7c45 编写于 作者: L lixinqi

fix tuple identity bugs

上级 6cc8e013
......@@ -140,12 +140,11 @@ void COCODataset::GetData(int64_t idx, DataInstance* data_inst) const {
auto* image_size_field = data_inst->GetField<DataSourceCase::kImageSize>();
auto* bbox_field = data_inst->GetField<DataSourceCase::kObjectBoundingBox>();
auto* label_field = data_inst->GetField<DataSourceCase::kObjectLabel>();
DataField* segm_field = nullptr;
if (data_inst->HasField<DataSourceCase::kObjectSegmentationAlignedMask>()) {
DataField* segm_field = data_inst->GetField<DataSourceCase::kObjectSegmentation>();
if (segm_field == nullptr
&& data_inst->HasField<DataSourceCase::kObjectSegmentationAlignedMask>()) {
segm_field = data_inst->GetOrCreateField<DataSourceCase::kObjectSegmentation>(
dataset_proto().coco().max_segm_poly_points());
} else {
segm_field = data_inst->GetField<DataSourceCase::kObjectSegmentation>();
}
auto* bbox_list = dynamic_cast<TensorListDataField<float>*>(bbox_field);
if (bbox_list != nullptr) { bbox_list->SetShape(4); }
......
......@@ -25,7 +25,10 @@ class PieceSliceKernel final : public KernelIf<device_type> {
Blob* out_blob = BnInOp2Blob("out_" + std::to_string(i));
auto out_tensor = out_blob->ReserveOneEmptyTensorList();
out_blob->AddTensor(&out_tensor);
out_tensor.set_shape(in_tensor.shape());
DimVector dim_vec;
in_tensor.shape().ToDimVector(&dim_vec);
dim_vec.erase(dim_vec.begin());
out_tensor.set_shape(Shape(dim_vec));
Memcpy<device_type>(ctx.device_ctx, out_tensor.mut_dptr(), in_tensor.dptr(),
in_tensor.ByteSize());
in_blob->MoveToNextTensor(&in_tensor);
......
......@@ -2,6 +2,18 @@
namespace oneflow {
template<DeviceType device_type>
void TupleIdentityKernel<device_type>::ForwardHeader(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& input_bns = this->op_attribute().input_bns();
const auto& output_bns = this->op_attribute().output_bns();
CHECK_EQ(input_bns.size(), output_bns.size());
FOR_RANGE(int, i, 0, input_bns.size()) {
Blob* out_blob = BnInOp2Blob(output_bns.Get(i));
out_blob->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob(input_bns.Get(i)));
}
}
template<DeviceType device_type>
void TupleIdentityKernel<device_type>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
......
......@@ -14,6 +14,8 @@ class TupleIdentityKernel final : public KernelIf<device_type> {
~TupleIdentityKernel() = default;
private:
void ForwardHeader(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
......
......@@ -513,6 +513,7 @@ if terminal_args.train_with_real_dataset:
random_seed,
shuffle,
group_by_aspect_ratio,
max_segm_poly_points=1024 * 1024,
)
data_loader = flow.data.DataLoader(coco, batch_size, batch_cache_size)
data_loader.add_blob(
......@@ -544,6 +545,14 @@ if terminal_args.train_with_real_dataset:
tensor_list_variable_axis=0,
is_dynamic=True,
)
# data_loader.add_blob(
# "gt_segm_poly",
# data_util.DataSourceCase.kObjectSegmentation,
# shape=(64, 2, 256, 2),
# dtype=flow.double,
# tensor_list_variable_axis=0,
# is_dynamic=True,
# )
data_loader.add_blob(
"gt_segm",
data_util.DataSourceCase.kObjectSegmentationAlignedMask,
......
......@@ -223,11 +223,18 @@ class COCODataset(object):
random_seed,
shuffle=True,
group_by_aspect_ratio=True,
max_segm_poly_points=1024,
name=None,
):
name = name or id_util.UniqueStr("COCODataset_")
self.__dict__.update(locals())
del self.self
self.dataset_dir = dataset_dir
self.annotation_file = annotation_file
self.image_dir = image_dir
self.random_seed = random_seed
self.shuffle = shuffle
self.group_by_aspect_ratio = group_by_aspect_ratio
self.max_segm_poly_points = max_segm_poly_points
self.name = name
def to_proto(self, proto=None):
if proto is None:
......@@ -240,6 +247,7 @@ class COCODataset(object):
proto.coco.annotation_file = self.annotation_file
proto.coco.image_dir = self.image_dir
proto.coco.group_by_aspect_ratio = self.group_by_aspect_ratio
proto.coco.max_segm_poly_points = self.max_segm_poly_points
return proto
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册