diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e3d2e5377eac49003b0082c39c9dd0460e2acd92..f87d5521492418d2daf5b7fba1500c4bb31e10f5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -469,6 +469,7 @@ class RuntimeInferShapeContext : public InferShapeContext { protected: DDim GetDim(const std::string& name) const override { Variable* var = scope_.FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var); if (var->IsType()) { return var->Get().dims(); } else if (var->IsType()) { diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b14b559e31dd422f8ebe4002988a9746dfdf28a2 --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/operators/random_crop_op.h" + +namespace paddle { +namespace operators { + +class RandomCropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "A batch of instances to random crop."); + AddInput("Seed", "The random seed."); + AddOutput("Out", "The cropped instance batch."); + AddOutput("SeedOut", "The random seed after random cropping.") + .AsDispensable(); + AddAttr>("shape", "The shape of a cropped instance."); + AddComment(R"DOC( + This operator takes a batch of instance, and do random cropping on each instance. + It means that cropping positions differs on each instance, which is determined + by an uniform random generator. All cropped instances have the same shape, which + is determined by the operator's attribute 'shape'. + )DOC"); + } +}; + +class RandomCropOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + auto seed_dim = ctx->GetInputDim("Seed"); + PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1); + auto shape = ctx->Attrs().Get>("shape"); + auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GT(x_dim.size(), static_cast(shape.size())); + auto out_dim = framework::vectorize2int(x_dim); + for (size_t i = 1; i <= shape.size(); ++i) { + size_t x_i = x_dim.size() - i; + size_t shape_i = shape.size() - i; + PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]); + out_dim[x_i] = shape[shape_i]; + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); + ctx->SetOutputDim("SeedOut", framework::make_ddim({1})); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace f = paddle::framework; +REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, + ops::RandomCropOpInferShape, f::EmptyGradOpMaker); + +template +using Kernel = ops::RandomCropKernel; +REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.cu b/paddle/fluid/operators/random_crop_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fc9bedc55b4d349ddf3d109c7f9049113235f0c --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/random_crop_op.h" + +namespace ops = paddle::operators; +template +using Kernel = ops::RandomCropKernel; +REGISTER_OP_CUDA_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f3261cbdc986b0cc724315c1eb92b8b84e18c742 --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.h @@ -0,0 +1,181 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif + +namespace paddle { +namespace operators { + +template +struct Random; + +template <> +struct Random { + using Engine = std::minstd_rand; + + template + using UniformIntDist = std::uniform_int_distribution; +}; + +#ifdef PADDLE_WITH_CUDA +template <> +struct Random { + using Engine = thrust::minstd_rand; + + template + using UniformIntDist = thrust::uniform_int_distribution; +}; +#endif + +template +HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, + const size_t* out_dims, int i, int rank, + size_t prod_x_remain, + size_t prod_out_remain, + const size_t* offsets) { + size_t x_dim_i = x_dims[i]; + size_t out_dim_i = out_dims[i]; + size_t x_stride = prod_x_remain / x_dim_i; + size_t out_stride = prod_out_remain / out_dim_i; + size_t offset_i = offsets[i]; + + if (i == rank - 1) { + PADDLE_ASSERT(x_stride == 1 && out_stride == 1); + x += offset_i; + for (size_t j = 0; j < out_dim_i; ++j) { + *out++ = *x++; + } + } else { + x += offset_i * x_stride; + for (size_t j = 0; j < out_dim_i; ++j) { + StridedMemcpy(x, x_dims, out, out_dims, i + 1, rank, x_stride, + out_stride, offsets); + x += x_stride; + out += out_stride; + } + } +} + +template +struct RandomCropFunctor { + const T* x_; + T* out_; + size_t x_dims_[9]; + size_t out_dims_[9]; + int num_batchsize_dims_; + int rank_; + int64_t seed_; + + size_t prod_batchsize_dims_; + size_t prod_x_ins_dims_; + size_t prod_out_ins_dims_; + + RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims, + const framework::DDim& out_dims, int num_batchsize_dims, + int64_t seed) + : x_(x), + out_(out), + num_batchsize_dims_(num_batchsize_dims), + rank_(x_dims.size()), + seed_(seed) { + PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size()); + PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_); + prod_batchsize_dims_ = 1; + prod_x_ins_dims_ = 1; + prod_out_ins_dims_ = 1; + for (size_t i = 0; i < static_cast(rank_); ++i) { + size_t x_dim_i = x_dims[i]; + size_t out_dim_i = out_dims[i]; + x_dims_[i] = x_dim_i; + out_dims_[i] = out_dim_i; + if (i < static_cast(num_batchsize_dims_)) { + PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i); + prod_batchsize_dims_ *= x_dim_i; + } else { + prod_x_ins_dims_ *= x_dim_i; + prod_out_ins_dims_ *= out_dim_i; + } + } + } + + HOSTDEVICE void operator()(size_t ins_idx) { + typename Random::Engine engine(seed_); + engine.discard(ins_idx * (rank_ - num_batchsize_dims_)); + size_t offsets[9]; + for (int i = num_batchsize_dims_; i < rank_; ++i) { + typename Random::template UniformIntDist dist( + 0, x_dims_[i] - out_dims_[i]); + offsets[i - num_batchsize_dims_] = dist(engine); + } + + const T* x = x_ + ins_idx * prod_x_ins_dims_; + T* out = out_ + ins_idx * prod_out_ins_dims_; + + StridedMemcpy(x, x_dims_ + num_batchsize_dims_, out, + out_dims_ + num_batchsize_dims_, 0, + rank_ - num_batchsize_dims_, prod_x_ins_dims_, + prod_out_ins_dims_, offsets); + } +}; + +template +class RandomCropKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto& seed_tensor = detail::Ref(ctx.Input("Seed")); + int64_t seed = 0; + if (platform::is_cpu_place(seed_tensor.place())) { + seed = *seed_tensor.data(); + } else { + LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify " + "your program"; + framework::LoDTensor cpu_seed; + framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed); + seed = *cpu_seed.data(); + } + auto shape = ctx.Attr>("shape"); + auto& x = detail::Ref(ctx.Input("X")); + auto& out = detail::Ref(ctx.Output("Out")); + + int num_batchsize_dims = x.dims().size() - shape.size(); + RandomCropFunctor functor( + x.data(), out.mutable_data(ctx.GetPlace()), x.dims(), out.dims(), + num_batchsize_dims, seed); + platform::ForRange for_range( + ctx.template device_context(), + functor.prod_batchsize_dims_); + + for_range(functor); + + Random::Engine engine(seed); + engine.discard(functor.prod_batchsize_dims_ * + (functor.rank_ - functor.num_batchsize_dims_)); + *ctx.Output("SeedOut")->mutable_data( + platform::CPUPlace()) = engine(); + } +}; + +// TODO(fengjiayi): Backward of random crop op + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e18b7acdeed4856ffdfbcc583c7d5e8bd2d79fb9..f049dd6fd9ec7485bba17feb09e8b68550353bc3 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -82,6 +82,7 @@ __all__ = [ 'roi_pool', 'dice_loss', 'upsampling_bilinear2d', + 'random_crop', ] @@ -154,7 +155,8 @@ def fc(input, Examples: .. code-block:: python - data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + data = fluid.layers.data( + name="data", shape=[32, 32], dtype="float32") fc = fluid.layers.fc(input=data, size=1000, act="tanh") """ @@ -349,7 +351,8 @@ def dynamic_lstm(input, cell_activation(str): The activation for cell output. Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh". candidate_activation(str): The activation for candidate hidden state. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer @@ -516,10 +519,12 @@ def dynamic_lstmp(input, cell_activation(str): The activation for cell output. Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh". candidate_activation(str): The activation for candidate hidden state. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". proj_activation(str): The activation for projection output. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer @@ -2174,7 +2179,8 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_mean(x) # [0.4375] fluid.layers.reduce_mean(x, dim=0) # [0.15, 0.25, 0.55, 0.8] fluid.layers.reduce_mean(x, dim=-1) # [0.475, 0.4] - fluid.layers.reduce_mean(x, dim=1, keep_dim=True) # [[0.475], [0.4]] + fluid.layers.reduce_mean( + x, dim=1, keep_dim=True) # [[0.475], [0.4]] # x is a Tensor variable with shape [2, 2, 2] and elements as below: # [[[1.0, 2.0], [3.0, 4.0]], @@ -2393,7 +2399,8 @@ def split(input, num_or_sections, dim=-1, name=None): x0.shape # [3, 3, 5] x1.shape # [3, 3, 5] x2.shape # [3, 3, 5] - x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1) + x0, x1, x2 = fluid.layers.split( + x, num_or_sections=[2, 3, 4], dim=1) x0.shape # [3, 2, 5] x1.shape # [3, 3, 5] x2.shape # [3, 4, 5] @@ -3305,7 +3312,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False): data = fluid.layers.data(name='data', shape=[128], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') fc = fluid.layers.fc(input=data, size=100) - out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label) + out = fluid.layers.softmax_with_cross_entropy( + logits=fc, label=label) """ helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_tmp_variable(dtype=logits.dtype) @@ -3352,7 +3360,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): .. code-block:: python data = fluid.layers.data(name='data', shape=[128], dtype='float32') - label = fluid.layers.data(name='label', shape=[100], dtype='float32') + label = fluid.layers.data( + name='label', shape=[100], dtype='float32') fc = fluid.layers.fc(input=data, size=100) out = fluid.layers.smooth_l1(x=fc, y=label) """ @@ -3674,7 +3683,8 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): Examples: .. code-block:: python - data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32") + data = fluid.layers.data( + name="data", shape=[3, 112, 112], dtype="float32") lrn = fluid.layers.lrn(input=data) """ helper = LayerHelper('lrn', **locals()) @@ -3929,10 +3939,10 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): Bilinear interpolation is an extension of linear interpolation for interpolating functions of two variables (e.g. H-direction and W-direction in this layer) on a rectilinear 2D grid. - + For details, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation - + Args: input (Variable): The input tensor of bilinear interpolation, This is a 4-D tensor of the shape @@ -3950,7 +3960,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): Returns: out (Variable): The output is a 4-D tensor of the shape (num_batches, channls, out_h, out_w). - + Examples: .. code-block:: python @@ -3983,3 +3993,32 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): attrs={"out_h": out_h, "out_w": out_w}) return out + + +def random_crop(input, shape, seed=1): + helper = LayerHelper("random_crop", **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + if isinstance(seed, int): + seed_value = seed + seed = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="fill_constant", + inputs={}, + outputs={"Out": seed}, + attrs={ + "dtype": seed.dtype, + "shape": [1], + "value": float(seed_value) + }) + elif not isinstance(seed, Variable): + raise ValueError("'seed' must be a Variable or an int.") + seed_out = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="random_crop", + inputs={"X": input, + "Seed": seed}, + outputs={"Out": out, + "SeedOut": seed_out}, + attrs={"shape": shape}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1c708d0386da4028f1f3d177d0a3fd494c077c6e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -0,0 +1,46 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + + +class TestRandomCropOp(OpTest): + def setUp(self): + to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] * + 5).astype("float32") + self.possible_res = [ + np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]), + np.array([[5, 6, 7], [9, 10, 11]]), + np.array([[6, 7, 8], [10, 11, 12]]) + ] + self.op_type = "random_crop" + self.inputs = {'X': to_crop, 'Seed': np.array([10])} + self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])} + self.attrs = {'shape': [2, 3]} + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + out = np.array(outs[1]) + for ins in out[:]: + is_equal = [(ins == res).all() for res in self.possible_res] + self.assertIn(True, is_equal) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce612ca2318ccb9b9f28d51cb93ce8e5e1d0680 Binary files /dev/null and b/tools/codestyle/docstring_checker.pyc differ