From 85a41df32d2793da5c1c49b9c36a3781567f4a7e Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Wed, 16 May 2018 15:14:16 +0800 Subject: [PATCH] Init commit --- paddle/fluid/operators/random_crop_op.cc | 59 ++++++++ paddle/fluid/operators/random_crop_op.h | 167 +++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 paddle/fluid/operators/random_crop_op.cc create mode 100644 paddle/fluid/operators/random_crop_op.h diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc new file mode 100644 index 00000000000..cb4bdde0eea --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/operators/random_crop_op.h" +#include + +namespace paddle { +namespace operators { +class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", ""); + AddOutput("Y", ""); + AddInput("Seed", ""); + AddOutput("SeedOut", "").AsDispensable(); + AddAttr>("shape", ""); + } +}; + +class RandomCropOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* context) const override { + auto shape = context->Attrs().Get>("shape"); + auto x_dim = context->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dim.size(), static_cast(shape.size())); + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == -1) { + shape[i] = static_cast(x_dim[i]); + } else { + PADDLE_ENFORCE_GE(x_dim[i], shape[i]); + } + } + context->SetOutputDim("Y", framework::make_ddim(shape)); + context->SetOutputDim("SeedOut", framework::make_ddim({1})); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace f = paddle::framework; +REGISTER_OPERATOR(random_crop, f::OperatorWithKernel, ops::RandomCropOpMaker, + ops::RandomCropOpInferShape); +template +using Kernel = ops::RandomCropKernel; + +REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h new file mode 100644 index 00000000000..86a22227f3f --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.h @@ -0,0 +1,167 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" +#include "thrust/random.h" + +namespace paddle { +namespace operators { + +template +struct Random; + +template <> +struct Random { + using Engine = std::minstd_rand; + + template + using UniformIntDist = std::uniform_int_distribution; +}; + +template <> +struct Random { + using Engine = thrust::minstd_rand; + + template + using UniformIntDist = thrust::uniform_int_distribution; +}; + +template +HOSTDEVICE inline void RandomCropImpl(const T* x, size_t* x_dim, T* out, + size_t* out_dim, int i, int rank, + int64_t prod_x_remain, + int64_t prod_out_remain, size_t* offset) { + size_t x_length = x_dim[rank]; + size_t out_length = out_dim[rank]; + + int64_t x_stride = prod_x_remain / x_length; + int64_t out_stride = prod_out_remain / out_length; + size_t offset_i = offset[i]; + if (x_stride == 1 && out_stride == 1) { + // In the final stage, copy from offset. + x += offset_i; + for (size_t i = 0; i < out_length; ++i) { + *out++ = *x++; + } + } else { + x += offset_i * x_stride; + for (size_t i = 0; i < out_length; ++i) { + RandomCropImpl(x, x_dim, out, out_dim, i + 1, rank, x_stride, + out_stride, offset); + x += x_stride; + out += out_stride; + } + } +} + +template +struct RandomCropFunctor { + const T* x_; + T* out_; + size_t x_dim_[9]; + size_t out_dim_[9]; + size_t prod_same_dim_; + + size_t prod_x_dim_; + size_t prod_out_dim_; + + int num_same_dim_; + int rank_; + + int64_t seed_; + + RandomCropFunctor(const T* x, T* out, int64_t seed) + : x_(x), + out_(out), + prod_same_dim_(1), + prod_x_dim_(1), + prod_out_dim_(1), + seed_(seed) { + std::fill(x_dim_, x_dim_ + sizeof(x_dim_) / sizeof(size_t), 0); + std::fill(out_dim_, out_dim_ + sizeof(out_dim_) / sizeof(size_t), 0); + } + + HOSTDEVICE void operator()(size_t i) { + typename Random::Engine engine(seed_); + engine.discard(i * (rank_ - num_same_dim_)); + + int64_t prod_x_unsame = (prod_x_dim_ / prod_same_dim_); + int64_t prod_out_unsame = (prod_out_dim_ / prod_same_dim_); + + const T* x = x_ + i * prod_x_unsame; + T* out = out_ + i * prod_out_unsame; + + size_t offset[9]; + for (int i = num_same_dim_; i < rank_; ++i) { + typename Random::template UniformIntDist dist( + 0, x_dim_[i] - out_dim_[i]); + offset[i] = dist(engine); + } + RandomCropImpl(x, x_dim_, out, out_dim_, num_same_dim_, rank_, + prod_x_unsame, prod_out_unsame, offset); + } +}; + +template +class RandomCropKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + int64_t seed = + *context.Input("Seed")->data(); + auto& x = detail::Ref(context.Input("X")); + auto& out = detail::Ref(context.Output("Out")); + + RandomCropFunctor functor{ + x.data(), out.mutable_data(context.GetPlace()), seed}; + + auto& out_dim = out.dims(); + auto& x_dim = x.dims(); + + auto rank = x_dim.size(); + while (rank-- > 0) { + functor.x_dim_[rank] = x_dim[rank]; + functor.out_dim_[rank] = out_dim[rank]; + functor.prod_x_dim_ *= x_dim[rank]; + functor.prod_out_dim_ *= out_dim[rank]; + if (x_dim[rank] != out_dim[rank]) { + PADDLE_ENFORCE_EQ(functor.prod_same_dim_, 1); + functor.num_same_dim_ = rank; + } else { + functor.prod_same_dim_ *= out_dim[rank]; + } + } + functor.rank_ = x_dim.size(); + + platform::ForRange for_range( + context.template device_context(), + functor.prod_same_dim_); + + for_range(functor); + + Random::Engine engine(seed); + engine.discard(functor.prod_same_dim_ * + (functor.rank_ - functor.num_same_dim_)); + + *context.Output("SeedOut")->mutable_data( + platform::CPUPlace()) = engine(); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab