Commit e71948f1 authored by fengjiayi

Refine random crop

1. Add a new attribute named 'startup_seed' to RandomCropOp. If the input
'Seed' is not initialized, 'startup_seed' is used in its place (see the
sketch below).

2. Refine CustomReader. Add a member variable 'scope_' that acts as the
global scope of preprocessing, making it possible to persist data across
batches.
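
As an illustration, here is a minimal sketch of how the new seeding behavior
surfaces in the Python API. The usage is hypothetical and not part of this
commit; it assumes the `fluid.layers` interface shown in the diff below.

    import paddle.fluid as fluid

    img = fluid.layers.data(name="img", shape=[3, 256, 256])
    # Passing an int seed no longer emits a fill_constant op. Instead, the
    # int becomes the op attribute 'startup_seed', and 'Seed' is wired to an
    # uninitialized persistable variable. On the first batch the kernel falls
    # back to 'startup_seed'; it then writes the advanced seed back to the
    # same variable through 'SeedOut', so the random sequence keeps evolving
    # across batches.
    cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224], seed=42)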
Parent e7faae01
@@ -37,6 +37,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("SeedOut", "The random seed after random cropping.")
         .AsIntermediate();
     AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
+    AddAttr<int>("startup_seed",
+                 "If the input 'Seed' is not initialized, the 'startup_seed' "
+                 "will be used to replace it. Even so, the seed after random "
+                 "crop will still be outputted to 'SeedOut'.")
+        .SetDefault(0);
     AddComment(R"DOC(
       This operator takes a batch of instances and does random cropping on each instance.
       It means that cropping positions differ on each instance, which is determined
@@ -49,8 +54,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
 class RandomCropOpInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {
-    auto seed_dim = ctx->GetInputDim("Seed");
-    PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
     auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
     auto x_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
@@ -62,7 +65,6 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
       out_dim[x_i] = shape[shape_i];
     }
     ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
-    ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
   }
 };
...
@@ -142,16 +142,22 @@ template <typename DeviceContext, typename T>
 class RandomCropKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
-    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
     int64_t seed = 0;
-    if (platform::is_cpu_place(seed_tensor.place())) {
-      seed = *seed_tensor.data<int64_t>();
+    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
+    if (seed_tensor.IsInitialized()) {
+      if (platform::is_cpu_place(seed_tensor.place())) {
+        seed = *seed_tensor.data<int64_t>();
+      } else {
+        LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
+                        "your program";
+        framework::LoDTensor cpu_seed;
+        framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
+        seed = *cpu_seed.data<int64_t>();
+      }
     } else {
-      LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
-                      "your program";
-      framework::LoDTensor cpu_seed;
-      framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
-      seed = *cpu_seed.data<int64_t>();
+      VLOG(5) << "WARNING: The input 'Seed' is not initialized, use attribute "
+                 "'startup_seed' instead.";
+      seed = ctx.Attr<int>("startup_seed");
     }
     auto shape = ctx.Attr<std::vector<int>>("shape");
     auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
@@ -171,7 +177,7 @@ class RandomCropKernel : public framework::OpKernel<T> {
     engine.discard(functor.prod_batchsize_dims_ *
                    (functor.rank_ - functor.num_batchsize_dims_));
     *ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
-        platform::CPUPlace()) = engine();
+        framework::make_ddim({1}), platform::CPUPlace()) = engine();
   }
 };
...
@@ -39,6 +39,7 @@ class CustomReader : public framework::DecoratedReader {
   const framework::ProgramDesc program_;
   int sub_block_id_;
   framework::Executor exe_;
+  framework::Scope scope_;
   std::vector<std::string> source_var_names_;
   std::vector<std::string> sink_var_names_;
@@ -158,20 +159,20 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   // The scope for CustomReader's sub-block should be independent and shouldn't
   // be any other computation scope's child. Otherwise, data preprocessing and
   // computation cannot be concurrent.
-  framework::Scope scope;
+  framework::Scope& exe_scope = scope_.NewScope();
   // 1. Copy LoDTensors from underlying reader's output to source variables.
   for (size_t i = 0; i < source_var_names_.size(); ++i) {
-    framework::Variable* var = scope.Var(source_var_names_[i]);
+    framework::Variable* var = exe_scope.Var(source_var_names_[i]);
     framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
     tensor->ShareDataWith(underlying_outs[i]);
     tensor->set_lod(underlying_outs[i].lod());
   }
   // 2. Run the sub-block.
-  exe_.Run(program_, &scope, sub_block_id_, false, true);
+  exe_.Run(program_, &exe_scope, sub_block_id_, false, true);
   // 3. Copy LoDTensors from sink variables to out.
   out->resize(sink_var_names_.size());
   for (size_t i = 0; i < sink_var_names_.size(); ++i) {
-    const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
+    const auto& tensor = detail::Ref(exe_scope.FindVar(sink_var_names_[i]))
                              .Get<framework::LoDTensor>();
     framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
   }
...
@@ -23,6 +23,7 @@ from layer_function_generator import autodoc, templatedoc
 from tensor import concat
 import utils
 import random
+from .. import unique_name
 
 __all__ = [
     'fc',
@@ -846,7 +847,7 @@ def crf_decoding(input, param_attr, label=None):
     Returns:
         Variable: ${viterbi_path_comment}
 
     Examples:
         .. code-block:: python
@@ -1084,7 +1085,7 @@ def chunk_eval(input,
     Here is a NER example of labeling for these tagging schemes:
 
     .. code-block:: python
 
        ====== ====== ====== ===== == ============ ===== ===== ===== == =========
               Li     Ming   works at Agricultural Bank  of    China in Beijing.
        ====== ====== ====== ===== == ============ ===== ===== ===== == =========
@@ -1110,7 +1111,7 @@ def chunk_eval(input,
     is the number of chunk types, and `tag_type` gets its value from the following table.
 
     .. code-block:: python
 
        Scheme Begin Inside End Single
        plain  0     -      -   -
        IOB    0     1      -   -
@@ -1146,7 +1147,7 @@ def chunk_eval(input,
         tuple: tuple containing: precision, recall, f1_score,
             num_infer_chunks, num_label_chunks,
             num_correct_chunks
 
     Examples:
         .. code-block:: python
@@ -1266,7 +1267,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
         param_attr (ParamAttr|None): attributes for parameter
         use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \
             library is installed. Default: True
 
     Returns:
         Variable: output of sequence_softmax
@@ -4143,7 +4144,7 @@ def one_hot(input, depth):
     Examples:
         .. code-block:: python
 
             label = layers.data(name="label", shape=[1], dtype="float32")
             one_hot_label = layers.one_hot(input=label, depth=10)
     """
@@ -4862,40 +4863,32 @@ def random_crop(x, shape, seed=None):
     Returns:
         ${out_comment}
 
     Examples:
         >>> img = fluid.layers.data("img", [3, 256, 256])
        >>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
     """
     helper = LayerHelper("random_crop", **locals())
-    dtype = helper.input_dtype()
+    dtype = x.dtype
     out = helper.create_tmp_variable(dtype)
     if seed is None:
         seed = random.randint(-65536, 65535)
+    op_attrs = {"shape": shape}
     if isinstance(seed, int):
-        seed_value = seed
-        seed = helper.create_tmp_variable(dtype="int64")
-        helper.append_op(
-            type="fill_constant",
-            inputs={},
-            outputs={"Out": seed},
-            attrs={
-                "dtype": seed.dtype,
-                "shape": [1],
-                "value": float(seed_value),
-                "force_cpu": True
-            })
+        op_attrs["startup_seed"] = seed
+        seed = helper.create_variable(
+            name=unique_name.generate("random_crop_seed"),
+            dtype="int64",
+            persistable=True)
     elif not isinstance(seed, Variable):
         raise ValueError("'seed' must be a Variable or an int.")
-    seed_out = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
         type="random_crop",
         inputs={"X": x,
                 "Seed": seed},
         outputs={"Out": out,
-                 "SeedOut": seed_out},
-        attrs={"shape": shape})
+                 "SeedOut": seed},
+        attrs=op_attrs)
     return out
@@ -4961,7 +4954,7 @@ def mean_iou(input, label, num_classes):
     semantic image segmentation, which first computes the IOU for each
     semantic class and then computes the average over classes.
     IOU is defined as follows:
 
     .. math::
 
         IOU = \\frac{true\_positive}{(true\_positive + false\_positive + false\_negative)}.
@@ -4984,7 +4977,7 @@ def mean_iou(input, label, num_classes):
     Examples:
         .. code-block:: python
 
             iou, wrongs, corrects = fluid.layers.mean_iou(predict, label, num_classes)
     """
     helper = LayerHelper('mean_iou', **locals())
...