diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc
index cb4bdde0eea1e220cbdaaa9d4f5d020d2589f615..b9367f1d224bcceae178a5baece66fb45c43f62e 100644
--- a/paddle/fluid/operators/random_crop_op.cc
+++ b/paddle/fluid/operators/random_crop_op.cc
@@ -12,36 +12,54 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/operators/random_crop_op.h"
-#include <vector>
 
 namespace paddle {
 namespace operators {
+
+class RandomCropOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
 class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  void Make() override {
     AddInput("X", "");
-    AddOutput("Y", "");
+    AddOutput("Out", "The cropped instance batch.");
     AddInput("Seed", "");
     AddOutput("SeedOut", "").AsDispensable();
     AddAttr<std::vector<int>>("shape", "");
+    AddComment("Crop every instance of X to `shape` at a random offset.");
   }
 };
 
 class RandomCropOpInferShape : public framework::InferShapeBase {
  public:
-  void operator()(framework::InferShapeContext* context) const override {
-    auto shape = context->Attrs().Get<std::vector<int>>("shape");
-    auto x_dim = context->GetInputDim("X");
-    PADDLE_ENFORCE_EQ(x_dim.size(), static_cast<int>(shape.size()));
-    for (size_t i = 0; i < shape.size(); ++i) {
-      if (shape[i] == -1) {
-        shape[i] = static_cast<int>(x_dim[i]);
-      } else {
-        PADDLE_ENFORCE_GE(x_dim[i], shape[i]);
-      }
+  void operator()(framework::InferShapeContext* ctx) const override {
+    auto seed_dim = ctx->GetInputDim("Seed");
+    PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
+    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int>(shape.size()));
+    auto out_dim = framework::vectorize2int(x_dim);
+    // Only the trailing `shape.size()` dimensions are cropped; all leading
+    // dimensions are kept as batch dimensions.
+    for (size_t i = 1; i <= shape.size(); ++i) {
+      size_t x_i = x_dim.size() - i;
+      size_t shape_i = shape.size() - i;
+      PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
+      out_dim[x_i] = shape[shape_i];
     }
-    context->SetOutputDim("Y", framework::make_ddim(shape));
-    context->SetOutputDim("SeedOut", framework::make_ddim({1}));
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
+    ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
   }
 };
 
@@ -50,8 +68,8 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
 
 namespace ops = paddle::operators;
 namespace f = paddle::framework;
-REGISTER_OPERATOR(random_crop, f::OperatorWithKernel, ops::RandomCropOpMaker,
-                  ops::RandomCropOpInferShape);
+REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
+                  ops::RandomCropOpInferShape, f::EmptyGradOpMaker);
 
 template <typename T>
 using Kernel = ops::RandomCropKernel<platform::CPUDeviceContext, T>;
diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h
index 86a22227f3f9a0e335694cde876b2aa763bc55a2..8764bd0bc7848f682f419d7096a41d8d7d349c64 100644
--- a/paddle/fluid/operators/random_crop_op.h
+++ b/paddle/fluid/operators/random_crop_op.h
@@ -14,11 +14,14 @@
 
 #pragma once
 
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "thrust/random.h"
+#ifdef PADDLE_WITH_CUDA
+#include <thrust/random.h>
+#endif
 
 namespace paddle {
 namespace operators {
@@ -34,6 +37,7 @@ struct Random<platform::CPUDeviceContext> {
   using UniformIntDist = std::uniform_int_distribution<T>;
 };
 
+#ifdef PADDLE_WITH_CUDA
 template <>
 struct Random<platform::CUDADeviceContext> {
   using Engine = thrust::minstd_rand;
@@ -41,29 +45,31 @@ struct Random<platform::CUDADeviceContext> {
   template <typename T>
   using UniformIntDist = thrust::uniform_int_distribution<T>;
 };
+#endif
 
 template <typename T>
-HOSTDEVICE inline void RandomCropImpl(const T* x, size_t* x_dim, T* out,
-                                      size_t* out_dim, int i, int rank,
-                                      int64_t prod_x_remain,
-                                      int64_t prod_out_remain, size_t* offset) {
-  size_t x_length = x_dim[rank];
-  size_t out_length = out_dim[rank];
-
-  int64_t x_stride = prod_x_remain / x_length;
-  int64_t out_stride = prod_out_remain / out_length;
-  size_t offset_i = offset[i];
-  if (x_stride == 1 && out_stride == 1) {
-    // In the final stage, copy from offset.
+HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
+                                     const size_t* out_dims, int i, int rank,
+                                     size_t prod_x_remain,
+                                     size_t prod_out_remain,
+                                     const size_t* offsets) {
+  size_t x_dim_i = x_dims[i];
+  size_t out_dim_i = out_dims[i];
+  size_t x_stride = prod_x_remain / x_dim_i;
+  size_t out_stride = prod_out_remain / out_dim_i;
+  size_t offset_i = offsets[i];
+
+  if (i == rank - 1) {
+    PADDLE_ENFORCE(x_stride == 1 && out_stride == 1);
     x += offset_i;
-    for (size_t i = 0; i < out_length; ++i) {
+    for (size_t j = 0; j < out_dim_i; ++j) {
       *out++ = *x++;
     }
   } else {
     x += offset_i * x_stride;
-    for (size_t i = 0; i < out_length; ++i) {
-      RandomCropImpl<T>(x, x_dim, out, out_dim, i + 1, rank, x_stride,
-                        out_stride, offset);
+    // Recurse over the `out_dim_i` elements that are kept; iterating over
+    // `x_dim_i` here would run past the end of the output buffer.
+    for (size_t j = 0; j < out_dim_i; ++j) {
+      StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
+                       out_stride, offsets);
       x += x_stride;
       out += out_stride;
     }
@@ -74,94 +80,102 @@ template <typename DeviceContext, typename T>
 struct RandomCropFunctor {
   const T* x_;
   T* out_;
-  size_t x_dim_[9];
-  size_t out_dim_[9];
-  size_t prod_same_dim_;
-
-  size_t prod_x_dim_;
-  size_t prod_out_dim_;
-
-  int num_same_dim_;
+  size_t x_dims_[9];
+  size_t out_dims_[9];
+  int num_batchsize_dims_;
   int rank_;
-
   int64_t seed_;
 
-  RandomCropFunctor(const T* x, T* out, int64_t seed)
+  size_t prod_x_dims_;
+  size_t prod_out_dims_;
+  size_t prod_batchsize_dims_;
+  size_t prod_x_ins_dims_;
+  size_t prod_out_ins_dims_;
+
+  RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims,
+                    const framework::DDim& out_dims, int num_batchsize_dims,
+                    int64_t seed)
       : x_(x),
         out_(out),
-        prod_same_dim_(1),
-        prod_x_dim_(1),
-        prod_out_dim_(1),
+        num_batchsize_dims_(num_batchsize_dims),
+        rank_(x_dims.size()),
         seed_(seed) {
-    std::fill(x_dim_, x_dim_ + sizeof(x_dim_) / sizeof(size_t), 0);
-    std::fill(out_dim_, out_dim_ + sizeof(out_dim_) / sizeof(size_t), 0);
+    PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size());
+    PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_);
+    prod_batchsize_dims_ = 1;
+    prod_x_ins_dims_ = 1;
+    prod_out_ins_dims_ = 1;
+    for (size_t i = 0; i < rank_; ++i) {
+      size_t x_dim_i = x_dims[i];
+      size_t out_dim_i = out_dims[i];
+      x_dims_[i] = x_dim_i;
+      out_dims_[i] = out_dim_i;
+      if (i < num_batchsize_dims_) {
+        PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i);
+        prod_batchsize_dims_ *= x_dim_i;
+      } else {
+        prod_x_ins_dims_ *= x_dim_i;
+        prod_out_ins_dims_ *= out_dim_i;
+      }
+    }
+    prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_;
+    prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_;
   }
 
-  HOSTDEVICE void operator()(size_t i) {
+  HOSTDEVICE void operator()(size_t ins_idx) {
     typename Random<DeviceContext>::Engine engine(seed_);
-    engine.discard(i * (rank_ - num_same_dim_));
-
-    int64_t prod_x_unsame = (prod_x_dim_ / prod_same_dim_);
-    int64_t prod_out_unsame = (prod_out_dim_ / prod_same_dim_);
-
-    const T* x = x_ + i * prod_x_unsame;
-    T* out = out_ + i * prod_out_unsame;
-
-    size_t offset[9];
-    for (int i = num_same_dim_; i < rank_; ++i) {
+    engine.discard(ins_idx * (rank_ - num_batchsize_dims_));
+    size_t offsets[9];
+    for (int i = num_batchsize_dims_; i < rank_; ++i) {
       typename Random<DeviceContext>::template UniformIntDist<size_t> dist(
-          0, x_dim_[i] - out_dim_[i]);
-      offset[i] = dist(engine);
+          0, x_dims_[i] - out_dims_[i]);
+      // StridedMemcpy reads offsets starting at index 0, so store them
+      // relative to the first cropped (non-batch) dimension.
+      offsets[i - num_batchsize_dims_] = dist(engine);
     }
-    RandomCropImpl<T>(x, x_dim_, out, out_dim_, num_same_dim_, rank_,
-                      prod_x_unsame, prod_out_unsame, offset);
+
+    const T* x = x_ + ins_idx * prod_x_ins_dims_;
+    T* out = out_ + ins_idx * prod_out_ins_dims_;
+
+    StridedMemcpy<T>(x, x_dims_ + num_batchsize_dims_, out,
+                     out_dims_ + num_batchsize_dims_, 0,
+                     rank_ - num_batchsize_dims_, prod_x_ins_dims_,
+                     prod_out_ins_dims_, offsets);
   }
 };
 
 template <typename DeviceContext, typename T>
 class RandomCropKernel : public framework::OpKernel<T> {
  public:
-  virtual void Compute(const framework::ExecutionContext& context) const {
-    int64_t seed =
-        *context.Input<framework::LoDTensor>("Seed")->data<int64_t>();
-    auto& x = detail::Ref(context.Input<framework::LoDTensor>("X"));
-    auto& out = detail::Ref(context.Output<framework::LoDTensor>("Out"));
-
-    RandomCropFunctor<DeviceContext, T> functor{
-        x.data<T>(), out.mutable_data<T>(context.GetPlace()), seed};
-
-    auto& out_dim = out.dims();
-    auto& x_dim = x.dims();
-
-    auto rank = x_dim.size();
-    while (rank-- > 0) {
-      functor.x_dim_[rank] = x_dim[rank];
-      functor.out_dim_[rank] = out_dim[rank];
-      functor.prod_x_dim_ *= x_dim[rank];
-      functor.prod_out_dim_ *= out_dim[rank];
-      if (x_dim[rank] != out_dim[rank]) {
-        PADDLE_ENFORCE_EQ(functor.prod_same_dim_, 1);
-        functor.num_same_dim_ = rank;
-      } else {
-        functor.prod_same_dim_ *= out_dim[rank];
-      }
-    }
-    functor.rank_ = x_dim.size();
-
+  virtual void Compute(const framework::ExecutionContext& ctx) const {
+    int64_t seed = *ctx.Input<framework::LoDTensor>("Seed")->data<int64_t>();
+    auto shape = ctx.Attr<std::vector<int>>("shape");
+    auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
+    auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
+
+    int num_batchsize_dims = x.dims().size() - shape.size();
+    RandomCropFunctor<DeviceContext, T> functor(
+        x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
+        num_batchsize_dims, seed);
     platform::ForRange<DeviceContext> for_range(
-        context.template device_context<DeviceContext>(),
-        functor.prod_same_dim_);
+        ctx.template device_context<DeviceContext>(),
+        functor.prod_batchsize_dims_);
 
     for_range(functor);
 
     Random<platform::CPUDeviceContext>::Engine engine(seed);
-    engine.discard(functor.prod_same_dim_ *
-                   (functor.rank_ - functor.num_same_dim_));
-
-    *context.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
+    // Advance the engine past every number this run consumed, so the next
+    // run starts from fresh offsets.
+    engine.discard(functor.prod_batchsize_dims_ *
+                   (functor.rank_ - functor.num_batchsize_dims_));
+    *ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
         platform::CPUPlace()) = engine();
   }
 };
 
+// TODO(fengjiayi): Backward of random crop op
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 04ee8ac9aee92a0e161e83bf1bb34d3ce727a0fb..42e26dd36653ed39680bc1fca84540349716f006 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -24,64 +24,19 @@ from tensor import concat
 import utils
 
 __all__ = [
-    'fc',
-    'embedding',
-    'dynamic_lstm',
-    'dynamic_lstmp',
-    'dynamic_gru',
-    'gru_unit',
-    'linear_chain_crf',
-    'crf_decoding',
-    'cos_sim',
-    'cross_entropy',
-    'square_error_cost',
-    'chunk_eval',
-    'sequence_conv',
-    'conv2d',
-    'sequence_pool',
-    'sequence_softmax',
-    'softmax',
-    'pool2d',
-    'batch_norm',
-    'beam_search_decode',
-    'conv2d_transpose',
-    'sequence_expand',
-    'lstm_unit',
-    'reduce_sum',
-    'reduce_mean',
-    'reduce_max',
-    'reduce_min',
-    'reduce_prod',
-    'sequence_first_step',
-    'sequence_last_step',
-    'dropout',
-    'split',
-    'ctc_greedy_decoder',
-    'edit_distance',
-    'l2_normalize',
-    'matmul',
-    'topk',
-    'warpctc',
-    'sequence_reshape',
-    'transpose',
-    'im2sequence',
-    'nce',
-    'beam_search',
-    'row_conv',
-    'multiplex',
-    'layer_norm',
-    'softmax_with_cross_entropy',
-    'smooth_l1',
-    'one_hot',
-    'autoincreased_step_counter',
-    'reshape',
-    'lod_reset',
-    'lrn',
-    'pad',
-    'label_smooth',
-    'roi_pool',
-    'dice_loss',
-    'bilinear_interp',
+    'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
+    'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
+    'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d',
+    'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'batch_norm',
+    'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit',
+    'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod',
+    'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
+    'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'topk',
+    'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', 'nce',
+    'beam_search', 'row_conv', 'multiplex', 'layer_norm',
+    'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
+    'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad',
+    'label_smooth', 'roi_pool', 'dice_loss', 'bilinear_interp', 'random_crop'
 ]
 
@@ -154,7 +109,8 @@ def fc(input,
     Examples:
         .. code-block:: python
 
-          data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
+          data = fluid.layers.data(
+              name="data", shape=[32, 32], dtype="float32")
           fc = fluid.layers.fc(input=data, size=1000, act="tanh")
     """
 
@@ -349,7 +305,8 @@ def dynamic_lstm(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
@@ -516,10 +473,12 @@ def dynamic_lstmp(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         proj_activation(str): The activation for projection output.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
@@ -2171,7 +2130,8 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
             fluid.layers.reduce_mean(x)  # [0.4375]
             fluid.layers.reduce_mean(x, dim=0)  # [0.15, 0.25, 0.55, 0.8]
             fluid.layers.reduce_mean(x, dim=-1)  # [0.475, 0.4]
-            fluid.layers.reduce_mean(x, dim=1, keep_dim=True)  # [[0.475], [0.4]]
+            fluid.layers.reduce_mean(
+                x, dim=1, keep_dim=True)  # [[0.475], [0.4]]
 
             # x is a Tensor variable with shape [2, 2, 2] and elements as below:
             #      [[[1.0, 2.0], [3.0, 4.0]],
@@ -2390,7 +2350,8 @@ def split(input, num_or_sections, dim=-1, name=None):
             x0.shape  # [3, 3, 5]
             x1.shape  # [3, 3, 5]
             x2.shape  # [3, 3, 5]
-            x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
+            x0, x1, x2 = fluid.layers.split(
+                x, num_or_sections=[2, 3, 4], dim=1)
             x0.shape  # [3, 2, 5]
             x1.shape  # [3, 3, 5]
             x2.shape  # [3, 4, 5]
@@ -3300,7 +3261,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
         data = fluid.layers.data(name='data', shape=[128], dtype='float32')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')
         fc = fluid.layers.fc(input=data, size=100)
-        out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)
+        out = fluid.layers.softmax_with_cross_entropy(
+            logits=fc, label=label)
     """
     helper = LayerHelper('softmax_with_cross_entropy', **locals())
     softmax = helper.create_tmp_variable(dtype=logits.dtype)
@@ -3347,7 +3309,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
         .. code-block:: python
 
             data = fluid.layers.data(name='data', shape=[128], dtype='float32')
-            label = fluid.layers.data(name='label', shape=[100], dtype='float32')
+            label = fluid.layers.data(
+                name='label', shape=[100], dtype='float32')
             fc = fluid.layers.fc(input=data, size=100)
             out = fluid.layers.smooth_l1(x=fc, y=label)
     """
@@ -3669,7 +3632,8 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
     Examples:
         .. code-block:: python
 
-          data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32")
+          data = fluid.layers.data(
+              name="data", shape=[3, 112, 112], dtype="float32")
           lrn = fluid.layers.lrn(input=data)
     """
     helper = LayerHelper('lrn', **locals())
@@ -3922,10 +3886,10 @@ def bilinear_interp(input, out_h, out_w, name=None):
     Bilinear interpolation is an extension of linear interpolation for
     interpolating functions of two variables (e.g. H-direction and
     W-direction in this layer) on a rectilinear 2D grid.
-    
+
     For details, please refer to Wikipedia:
     https://en.wikipedia.org/wiki/Bilinear_interpolation
-    
+
     Args:
         input (Variable): The input tensor of bilinear interpolation,
                           This is a 4-D tensor of the shape
@@ -3938,7 +3902,7 @@ def bilinear_interp(input, out_h, out_w, name=None):
     Returns:
         out (Variable): The output is a 4-D tensor of the shape
                         (num_batches, channls, out_h, out_w).
-    
+
     Examples:
         .. code-block:: python
 
@@ -3954,3 +3918,28 @@ def bilinear_interp(input, out_h, out_w, name=None):
         attrs={"out_h": out_h,
                "out_w": out_w})
     return out
+
+
+def random_crop(input, shape, seed=0):
+    helper = LayerHelper("random_crop", **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    if isinstance(seed, int):
+        # Keep the int value before rebinding `seed` to the new Variable,
+        # so the initializer receives the number rather than the Variable.
+        seed_value = seed
+        seed = helper.create_global_variable(
+            persistable=True, shape=[1], dtype="int64")
+        helper.set_variable_initializer(
+            var=seed, initializer=Constant(value=seed_value))
+    elif not isinstance(seed, Variable):
+        raise ValueError("'seed' must be a Variable or an int.")
+    seed_out = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="random_crop",
+        inputs={"X": input,
+                "Seed": seed},
+        outputs={"Out": out,
+                 "SeedOut": seed_out},
+        attrs={"shape": shape})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e609e2c99fba733325b24991a701270170637bda
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+class TestRandomCropOp(OpTest):
+    def setUp(self):
+        to_crop = np.random.random((1, 10, 15)).astype("float32")
+        self.op_type = "random_crop"
+        self.inputs = {'X': to_crop, 'Seed': np.array([10])}
+        # The crop offset is random, so these reference outputs are mere
+        # placeholders; the real check happens in verify_output().
+        self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])}
+        self.attrs = {'shape': [5, 5]}
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def verify_output(self, outs):
+        # Only the cropped shape is deterministic; offsets vary with the seed.
+        out_shapes = [tuple(np.array(out).shape) for out in outs]
+        self.assertIn((1, 5, 5), out_shapes)
+
+
+if __name__ == "__main__":
+    unittest.main()
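
For reference, the semantics this patch implements — shape inference keeps every leading dimension of `X` as a batch dimension and crops only the trailing `len(shape)` dimensions, with one uniform random offset per cropped dimension per instance — can be sketched in a few lines of NumPy. This is a rough reference model only: the helper name `random_crop_ref` is invented for illustration, and it does not reproduce the kernel's actual random sequence (a minstd engine advanced with `discard`).

    import numpy as np

    def random_crop_ref(x, crop_shape, seed=0):
        # The trailing len(crop_shape) dims are cropped; leading dims are
        # batch dims, and each instance gets its own crop window.
        rng = np.random.RandomState(seed)
        num_batch_dims = x.ndim - len(crop_shape)
        batch_shape = x.shape[:num_batch_dims]
        out = np.empty(batch_shape + tuple(crop_shape), dtype=x.dtype)
        for idx in np.ndindex(*batch_shape):
            # One independent uniform offset per cropped dimension.
            offsets = [rng.randint(0, d - c + 1)
                       for d, c in zip(x.shape[num_batch_dims:], crop_shape)]
            window = tuple(slice(o, o + c)
                           for o, c in zip(offsets, crop_shape))
            out[idx] = x[idx][window]
        return out

    x = np.arange(150, dtype=np.float32).reshape(1, 10, 15)
    print(random_crop_ref(x, [5, 5]).shape)  # (1, 5, 5)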
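The new `fluid.layers.random_crop` layer lands without a docstring, so a minimal usage sketch may help; it assumes the `fluid.layers.data` API as used elsewhere in this patch, with the batch dimension implicit:

    import paddle.fluid as fluid

    img = fluid.layers.data(name="img", shape=[3, 256, 256], dtype="float32")
    # Crop each image in the batch to 3 x 224 x 224 at a random offset;
    # the batch dimension itself is never cropped.
    cropped = fluid.layers.random_crop(img, shape=[3, 224, 224])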