Unverified commit b4f28ccc, authored by fengjiayi, committed by GitHub

Merge pull request #11632 from JiayiFeng/some_small_fixes

Some small fixes
......@@ -70,6 +70,7 @@ $$Out = values$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>);
......@@ -37,6 +37,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("SeedOut", "The random seed after random cropping.")
.AsIntermediate();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddAttr<int>("startup_seed",
"If the input 'Seed' is not initialized, the 'startup_seed' "
"will be used to replace it. Even so, the seed after random "
"crop will also be outputed to the 'SeedOut'.")
.SetDefault(0);
AddComment(R"DOC(
This operator takes a batch of instances and does random cropping on each instance.
It means that the cropping positions differ on each instance, which is determined
......@@ -49,8 +54,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
class RandomCropOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
auto seed_dim = ctx->GetInputDim("Seed");
PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
......@@ -62,7 +65,6 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
}
};
......
......@@ -142,16 +142,22 @@ template <typename DeviceContext, typename T>
class RandomCropKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
int64_t seed = 0;
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
if (seed_tensor.IsInitialized()) {
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
}
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
VLOG(5) << "WARNING: The input 'Seed' is not initialized, use attribute "
"'startup_seed' instead.";
seed = ctx.Attr<int>("startup_seed");
}
auto shape = ctx.Attr<std::vector<int>>("shape");
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
......@@ -171,7 +177,7 @@ class RandomCropKernel : public framework::OpKernel<T> {
engine.discard(functor.prod_batchsize_dims_ *
(functor.rank_ - functor.num_batchsize_dims_));
*ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
platform::CPUPlace()) = engine();
framework::make_ddim({1}), platform::CPUPlace()) = engine();
}
};
......
......@@ -39,6 +39,7 @@ class CustomReader : public framework::DecoratedReader {
const framework::ProgramDesc program_;
int sub_block_id_;
framework::Executor exe_;
framework::Scope scope_;
std::vector<std::string> source_var_names_;
std::vector<std::string> sink_var_names_;
......@@ -158,23 +159,24 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// computation cannot be concurrent.
framework::Scope scope;
framework::Scope* exe_scope = &scope_.NewScope();
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope.Var(source_var_names_[i]);
framework::Variable* var = exe_scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
exe_.Run(program_, &scope, sub_block_id_, false, true);
exe_.Run(program_, exe_scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i]))
.Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
scope_.DeleteScope(exe_scope);
}
} // namespace reader
......
......@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize should be at least 2.
static constexpr size_t kCacheSize = 3;
static constexpr size_t kCacheSize = 5;
// There will be two batches out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kCacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
......
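As a quick illustration of the size arithmetic in the comment above, here is a minimal Python sketch (not PaddlePaddle code; the constant names merely mirror the C++ ones): of the kCacheSize batches alive at any moment, one is being consumed downstream and one is waiting to be enqueued, so the channel itself only needs kCacheSize - 2 slots.

    # Illustrative bookkeeping only; the real constants are the C++ ones above.
    CACHE_SIZE = 5        # mirrors kCacheSize: batches alive at any moment
    OUTSIDE_CHANNEL = 2   # one batch being consumed + one waiting to be sent
    CHANNEL_SIZE = CACHE_SIZE - OUTSIDE_CHANNEL   # mirrors kChannelSize
    assert CHANNEL_SIZE == 3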
......@@ -110,7 +110,7 @@ class BlockGuardServ(BlockGuard):
class ListenAndServ(object):
"""
**ListenAndServ Layer**
ListenAndServ is used to create an RPC server that binds and listens
on a specific TCP port. This server will run the sub-block when it
receives variables from clients.
......@@ -212,7 +212,7 @@ def Send(endpoints, send_vars, sync=True):
of send_vars to send
send_vars (list): variables to send to server
sync (bool): whether to wait for the request to finish
"""
assert (type(send_vars) == list)
......@@ -469,10 +469,13 @@ def open_files(filenames,
lod_levels(list): List of ints declaring data lod_level.
dtypes(list): List of strs declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
buffer_size(int|None): The size of the prefetch buffer. If it is set to None,
the buffer size will be thread_num * 3.
Default: None
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it to True if you are going to run
subsequent operators in parallel.
Default: True
Returns:
Variable: A Reader Variable via which we can get file data.
......@@ -492,7 +495,7 @@ def open_files(filenames,
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
buffer_size = thread_num * 3
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
......
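A hedged usage sketch of the new default: if buffer_size is left unset, it now falls back to thread_num * 3 instead of thread_num. The file names and the shapes argument below are illustrative assumptions, not taken from this diff.

    import paddle.fluid as fluid

    # With thread_num=4 and no explicit buffer_size, the prefetch buffer
    # defaults to 4 * 3 = 12 after this change.
    reader = fluid.layers.io.open_files(
        filenames=['./data1.recordio', './data2.recordio'],  # illustrative paths
        shapes=[(-1, 3, 224, 224), (-1, 1)],                 # assumed shapes
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'],
        thread_num=4)
    image, label = fluid.layers.io.read_file(reader)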
......@@ -23,6 +23,7 @@ from layer_function_generator import autodoc, templatedoc
from tensor import concat
import utils
import random
from .. import unique_name
__all__ = [
'fc',
......@@ -4896,34 +4897,26 @@ def random_crop(x, shape, seed=None):
>>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
"""
helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype()
dtype = x.dtype
out = helper.create_tmp_variable(dtype)
if seed is None:
seed = random.randint(-65536, 65535)
op_attrs = {"shape": shape}
if isinstance(seed, int):
seed_value = seed
seed = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="fill_constant",
inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value),
"force_cpu": True
})
op_attrs["startup_seed"] = seed
seed = helper.create_variable(
name=unique_name.generate("random_crop_seed"),
dtype="int64",
persistable=True)
elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": x,
"Seed": seed},
outputs={"Out": out,
"SeedOut": seed_out},
attrs={"shape": shape})
"SeedOut": seed},
attrs=op_attrs)
return out
......
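A hedged usage sketch of the rewritten layer: with an integer (or omitted) seed, random_crop no longer appends a fill_constant op; the integer goes into the new startup_seed attribute, and a persistent int64 seed variable is created and updated in place through SeedOut. Variable names here are illustrative.

    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[3, 256, 256], dtype='float32')

    # Integer seed: stored in the 'startup_seed' attribute; the op writes the
    # evolving seed back into its own persistent seed variable via 'SeedOut'.
    cropped = fluid.layers.random_crop(img, shape=[3, 224, 224], seed=42)

    # Passing anything other than an int or a Variable now raises ValueError.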
......@@ -155,7 +155,7 @@ def cast(x, dtype):
Examples:
.. code-block:: python
data = fluid.layers.data(name='x', shape=[13], dtype='float32')
result = fluid.layers.cast(x=data, dtype='float64')
"""
......@@ -188,7 +188,7 @@ def concat(input, axis=0, name=None):
Examples:
.. code-block:: python
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
"""
helper = LayerHelper('concat', **locals())
......@@ -238,7 +238,7 @@ def sums(input, out=None):
return out
def assign(input, output):
def assign(input, output=None):
"""
**Assign**
......@@ -246,7 +246,7 @@ def assign(input, output):
Args:
input(Variable|numpy.ndarray): The source variable
output(Variable): The destination variable
output(Variable|None): The destination variable
Returns:
Variable: The destination variable that was supplied as the *output*.
......@@ -259,6 +259,8 @@ def assign(input, output):
fluid.layers.assign(hidden, out)
"""
helper = LayerHelper('assign', **locals())
if output is None:
output = helper.create_tmp_variable(dtype=input.dtype)
if isinstance(input, Variable):
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
......
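A hedged usage sketch of the relaxed assign signature: when output is omitted, the layer now creates a temporary variable with the input's dtype and returns it. The surrounding layer calls are illustrative.

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    hidden = fluid.layers.fc(input=x, size=10)

    # Previous usage: the caller supplies the destination variable.
    out = fluid.layers.create_tensor(dtype='float32')
    fluid.layers.assign(hidden, out)

    # New usage: omit 'output'; assign creates and returns the destination itself.
    copied = fluid.layers.assign(hidden)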