提交 e71948f1 编写于 作者: F fengjiayi

Refine random crop

1. Add a new attribute named 'startuo_seed' to RandomCropOp. If the input
'Seed' is not initialized, the 'startup_seed' will be used to replace
it.

2. Refine CustomReader. Add a member variable 'scope_' to it. The
'scope_' will act as the global scope of preprocesing, making it
possiable to save something cross batches.
上级 e7faae01
......@@ -37,6 +37,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("SeedOut", "The random seed after random cropping.")
.AsIntermediate();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddAttr<int>("startup_seed",
"If the input 'Seed' is not initialized, the 'startup_seed' "
"will be used to replace it. Even so, the seed after random "
"crop will also be outputed to the 'SeedOut'.")
.SetDefault(0);
AddComment(R"DOC(
This operator takes a batch of instance, and do random cropping on each instance.
It means that cropping positions differs on each instance, which is determined
......@@ -49,8 +54,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
class RandomCropOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
auto seed_dim = ctx->GetInputDim("Seed");
PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
......@@ -62,7 +65,6 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
}
};
......
......@@ -142,16 +142,22 @@ template <typename DeviceContext, typename T>
class RandomCropKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
int64_t seed = 0;
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
if (seed_tensor.IsInitialized()) {
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
}
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
VLOG(5) << "WARNING: The input 'Seed' is not initialized, use attribute "
"'startup_seed' instead.";
seed = ctx.Attr<int>("startup_seed");
}
auto shape = ctx.Attr<std::vector<int>>("shape");
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
......@@ -171,7 +177,7 @@ class RandomCropKernel : public framework::OpKernel<T> {
engine.discard(functor.prod_batchsize_dims_ *
(functor.rank_ - functor.num_batchsize_dims_));
*ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
platform::CPUPlace()) = engine();
framework::make_ddim({1}), platform::CPUPlace()) = engine();
}
};
......
......@@ -39,6 +39,7 @@ class CustomReader : public framework::DecoratedReader {
const framework::ProgramDesc program_;
int sub_block_id_;
framework::Executor exe_;
framework::Scope scope_;
std::vector<std::string> source_var_names_;
std::vector<std::string> sink_var_names_;
......@@ -158,20 +159,20 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// compution cannot be concurrent.
framework::Scope scope;
framework::Scope& exe_scope = scope_.NewScope();
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope.Var(source_var_names_[i]);
framework::Variable* var = exe_scope.Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
exe_.Run(program_, &scope, sub_block_id_, false, true);
exe_.Run(program_, &exe_scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
const auto& tensor = detail::Ref(exe_scope.FindVar(sink_var_names_[i]))
.Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
......
......@@ -23,6 +23,7 @@ from layer_function_generator import autodoc, templatedoc
from tensor import concat
import utils
import random
from .. import unique_name
__all__ = [
'fc',
......@@ -846,7 +847,7 @@ def crf_decoding(input, param_attr, label=None):
Returns:
Variable: ${viterbi_path_comment}
Examples:
.. code-block:: python
......@@ -1084,7 +1085,7 @@ def chunk_eval(input,
Here is a NER example of labeling for these tagging schemes:
.. code-block:: python
====== ====== ====== ===== == ============ ===== ===== ===== == =========
Li Ming works at Agricultural Bank of China in Beijing.
====== ====== ====== ===== == ============ ===== ===== ===== == =========
......@@ -1110,7 +1111,7 @@ def chunk_eval(input,
is the num of chunk types, and `tag_type` get its value from the following table.
.. code-block:: python
Scheme Begin Inside End Single
plain 0 - - -
IOB 0 1 - -
......@@ -1146,7 +1147,7 @@ def chunk_eval(input,
tuple: tuple containing: precision, recall, f1_score,
num_infer_chunks, num_label_chunks,
num_correct_chunks
Examples:
.. code-block:: python
......@@ -1266,7 +1267,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
param_attr (ParamAttr|None): attributes for parameter
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \
library is installed. Default: True
Returns:
Variable: output of sequence_softmax
......@@ -4143,7 +4144,7 @@ def one_hot(input, depth):
Examples:
.. code-block:: python
label = layers.data(name="label", shape=[1], dtype="float32")
one_hot_label = layers.one_hot(input=label, depth=10)
"""
......@@ -4862,40 +4863,32 @@ def random_crop(x, shape, seed=None):
Returns:
${out_comment}
Examples:
>>> img = fluid.layers.data("img", [3, 256, 256])
>>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
"""
helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype()
dtype = x.dtype
out = helper.create_tmp_variable(dtype)
if seed is None:
seed = random.randint(-65536, 65535)
op_attrs = {"shape": shape}
if isinstance(seed, int):
seed_value = seed
seed = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="fill_constant",
inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value),
"force_cpu": True
})
op_attrs["startup_seed"] = seed
seed = helper.create_variable(
name=unique_name.generate("random_crop_seed"),
dtype="int64",
persistable=True)
elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": x,
"Seed": seed},
outputs={"Out": out,
"SeedOut": seed_out},
attrs={"shape": shape})
"SeedOut": seed},
attrs=op_attrs)
return out
......@@ -4961,7 +4954,7 @@ def mean_iou(input, label, num_classes):
semantic image segmentation, which first computes the IOU for each
semantic class and then computes the average over classes.
IOU is defined as follows:
.. math::
IOU = \\frac{true\_positiv}{(true\_positive + false\_positive + false\_negative)}.
......@@ -4984,7 +4977,7 @@ def mean_iou(input, label, num_classes):
Examples:
.. code-block:: python
iou, wrongs, corrects = fluid.layers.mean_iou(predict, label, num_classes)
"""
helper = LayerHelper('mean_iou', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册