diff --git a/paddle/contrib/inference/CMakeLists.txt b/paddle/contrib/inference/CMakeLists.txt
index 5043e4e02c8cb3a6209c3f66d1f1faa6fa943fa6..9c55f189bcc5cbf0ce84f11e9653fa20b84a51f7 100644
--- a/paddle/contrib/inference/CMakeLists.txt
+++ b/paddle/contrib/inference/CMakeLists.txt
@@ -17,7 +17,7 @@ if(APPLE)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
 endif(APPLE)
 
-function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
+function(inference_api_test TARGET_NAME TEST_SRC)
   set(options "")
   set(oneValueArgs "")
   set(multiValueArgs ARGS)
@@ -38,6 +38,8 @@ function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
             SRCS ${TEST_SRC}
             DEPS paddle_fluid_api paddle_inference_api paddle_inference_api_impl
             ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
+    # TODO(panyx0178): Figure out how to add word2vec and image_classification
+    # as deps.
     # set_tests_properties(${TARGET_NAME}
     #                      PROPERTIES DEPENDS ${DEP_TEST})
   endforeach()
@@ -57,5 +59,4 @@ cc_test(test_paddle_inference_api
         DEPS paddle_inference_api)
 
 inference_api_test(test_paddle_inference_api_impl
-                   test_paddle_inference_api_impl.cc
-                   test_word2vec)
+                   test_paddle_inference_api_impl.cc)
diff --git a/paddle/contrib/inference/paddle_inference_api_impl.cc b/paddle/contrib/inference/paddle_inference_api_impl.cc
index e7a0b341dda1ca8d2ccfc0d6c12a7ac3d4c691d5..ebe4c3291802707009f30616463705d966e244d6 100644
--- a/paddle/contrib/inference/paddle_inference_api_impl.cc
+++ b/paddle/contrib/inference/paddle_inference_api_impl.cc
@@ -102,8 +102,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, paddle::framework::LoDTensor *> feed_targets;
-  std::vector<paddle::framework::LoDTensor> feeds;
+  std::map<std::string, framework::LoDTensor *> feed_targets;
+  std::vector<framework::LoDTensor> feeds;
   if (!SetFeed(inputs, &feeds)) {
     LOG(ERROR) << "fail to set feed";
     return false;
@@ -112,8 +112,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
     feed_targets[feed_target_names_[i]] = &feeds[i];
   }
   // get fetch variable
-  std::map<std::string, paddle::framework::LoDTensor *> fetch_targets;
-  std::vector<paddle::framework::LoDTensor> fetchs;
+  std::map<std::string, framework::LoDTensor *> fetch_targets;
+  std::vector<framework::LoDTensor> fetchs;
   fetchs.resize(fetch_target_names_.size());
   for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
     fetch_targets[fetch_target_names_[i]] = &fetchs[i];
@@ -149,28 +149,27 @@ bool PaddlePredictorImpl::InitShared() {
   VLOG(3) << "Predictor::init_shared";
   // 1. Define place, executor, scope
   if (this->config_.device >= 0) {
-    place_ = paddle::platform::CUDAPlace();
+    place_ = platform::CUDAPlace();
   } else {
-    place_ = paddle::platform::CPUPlace();
+    place_ = platform::CPUPlace();
   }
-  this->executor_.reset(new paddle::framework::Executor(this->place_));
-  this->scope_.reset(new paddle::framework::Scope());
+  this->executor_.reset(new framework::Executor(this->place_));
+  this->scope_.reset(new framework::Scope());
   // Initialize the inference program
   if (!this->config_.model_dir.empty()) {
     // Parameters are saved in separate files sited in
     // the specified `dirname`.
-    this->inference_program_ = paddle::inference::Load(
+    this->inference_program_ = inference::Load(
         this->executor_.get(), this->scope_.get(), this->config_.model_dir);
   } else if (!this->config_.prog_file.empty() &&
              !this->config_.param_file.empty()) {
     // All parameters are saved in a single file.
     // The file names should be consistent with that used
     // in Python API `fluid.io.save_inference_model`.
-    this->inference_program_ =
-        paddle::inference::Load(this->executor_.get(),
-                                this->scope_.get(),
-                                this->config_.prog_file,
-                                this->config_.param_file);
+    this->inference_program_ = inference::Load(this->executor_.get(),
+                                               this->scope_.get(),
+                                               this->config_.prog_file,
+                                               this->config_.param_file);
   }
   this->ctx_ = this->executor_->Prepare(*this->inference_program_, 0);
   // 3. create variables
@@ -185,24 +184,21 @@ bool PaddlePredictorImpl::InitShared() {
   return true;
 }
 
-bool PaddlePredictorImpl::SetFeed(
-    const std::vector<PaddleTensor> &inputs,
-    std::vector<paddle::framework::LoDTensor> *feeds) {
+bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
+                                  std::vector<framework::LoDTensor> *feeds) {
   VLOG(3) << "Predictor::set_feed";
   if (inputs.size() != feed_target_names_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
   for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    paddle::framework::LoDTensor input;
-    paddle::framework::DDim ddim =
-        paddle::framework::make_ddim(inputs[i].shape);
+    framework::LoDTensor input;
+    framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
     if (inputs[i].dtype == PaddleDType::INT64) {
-      input_ptr =
-          input.mutable_data<int64_t>(ddim, paddle::platform::CPUPlace());
+      input_ptr = input.mutable_data<int64_t>(ddim, platform::CPUPlace());
     } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
-      input_ptr = input.mutable_data<float>(ddim, paddle::platform::CPUPlace());
+      input_ptr = input.mutable_data<float>(ddim, platform::CPUPlace());
     } else {
       LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
       return false;
@@ -213,13 +209,12 @@ bool PaddlePredictorImpl::SetFeed(
                 inputs[i].data.data,
                 inputs[i].data.length);
     feeds->push_back(input);
-    LOG(ERROR) << "Actual feed type " << feeds->back().type().name();
   }
   return true;
 }
 
 bool PaddlePredictorImpl::GetFetch(
-    const std::vector<paddle::framework::LoDTensor> &fetchs,
+    const std::vector<framework::LoDTensor> &fetchs,
     std::vector<PaddleTensor> *outputs) {
   VLOG(3) << "Predictor::get_fetch";
   outputs->resize(fetchs.size());
@@ -284,8 +279,9 @@ bool PaddlePredictorImpl::GetFetch(
   return true;
 }
 
-std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
-    const VisConfig &config) {
+template <>
+std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
+    const ConfigImpl &config) {
   VLOG(3) << "create PaddlePredictorImpl";
   // 1. GPU memory
   std::vector<std::string> flags;
@@ -299,12 +295,11 @@ std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
     framework::InitGflags(flags);
   }
 
-  std::unique_ptr<PaddlePredictorImpl> predictor(
-      new PaddlePredictorImpl(config));
-  if (!predictor->Init()) {
+  std::unique_ptr<PaddlePredictor> predictor(new PaddlePredictorImpl(config));
+  if (!dynamic_cast<PaddlePredictorImpl *>(predictor.get())->Init()) {
     return nullptr;
   }
-  return predictor;
+  return std::move(predictor);
 }
 
 }  // namespace paddle
diff --git a/paddle/contrib/inference/paddle_inference_api_impl.h b/paddle/contrib/inference/paddle_inference_api_impl.h
index a0c7ff030735fc1c6b9d717f8f9e4addc7e0c6b0..c545461680723b429b2253392060ea36b84ce708 100644
--- a/paddle/contrib/inference/paddle_inference_api_impl.h
+++ b/paddle/contrib/inference/paddle_inference_api_impl.h
@@ -29,7 +29,7 @@
 
 namespace paddle {
 
-struct VisConfig : public PaddlePredictor::Config {
+struct ConfigImpl : public PaddlePredictor::Config {
   int device;
   float fraction_of_gpu_memory;
   std::string prog_file;
@@ -37,12 +37,9 @@ struct VisConfig : public PaddlePredictor::Config {
   bool share_variables;
 };
 
-/*
- * Do not use this, just a demo indicating how to customize a Predictor.
- */
 class PaddlePredictorImpl : public PaddlePredictor {
  public:
-  explicit PaddlePredictorImpl(const VisConfig &config) : config_(config) {}
+  explicit PaddlePredictorImpl(const ConfigImpl &config) : config_(config) {}
 
   bool Init();
 
@@ -56,21 +53,18 @@ class PaddlePredictorImpl : public PaddlePredictor {
  private:
   bool InitShared() override;
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
-               std::vector<paddle::framework::LoDTensor> *feeds);
-  bool GetFetch(const std::vector<paddle::framework::LoDTensor> &fetchs,
+               std::vector<framework::LoDTensor> *feeds);
+  bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
                 std::vector<PaddleTensor> *output_data);
 
-  VisConfig config_;
-  paddle::platform::Place place_;
-  std::unique_ptr<paddle::framework::Executor> executor_;
-  std::unique_ptr<paddle::framework::Scope> scope_;
-  std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx_;
-  std::unique_ptr<paddle::framework::ProgramDesc> inference_program_;
+  ConfigImpl config_;
+  platform::Place place_;
+  std::unique_ptr<framework::Executor> executor_;
+  std::unique_ptr<framework::Scope> scope_;
+  std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
+  std::unique_ptr<framework::ProgramDesc> inference_program_;
   std::vector<std::string> feed_target_names_;
   std::vector<std::string> fetch_target_names_;
 };
 
-std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
-    const VisConfig &config);
-
 }  // namespace paddle
diff --git a/paddle/contrib/inference/test_paddle_inference_api_impl.cc b/paddle/contrib/inference/test_paddle_inference_api_impl.cc
index 2a58f6989d5dad23b2f267adafde2cc105bf5651..096293a4e25df0c78150d85dc091d7ca6539bf40 100644
--- a/paddle/contrib/inference/test_paddle_inference_api_impl.cc
+++ b/paddle/contrib/inference/test_paddle_inference_api_impl.cc
@@ -40,16 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
   return pt;
 }
 
-TEST(paddle_inference_api_impl, word2vec) {
-  VisConfig config;
+ConfigImpl GetConfig() {
+  ConfigImpl config;
   config.model_dir = FLAGS_dirname + "word2vec.inference.model";
   LOG(INFO) << "dirname  " << config.model_dir;
   config.fraction_of_gpu_memory = 0.15;
   config.device = 0;
   config.share_variables = true;
+  return config;
+}
 
-  std::unique_ptr<PaddlePredictorImpl> predictor =
-      CreatePaddlePredictorImpl(config);
+TEST(paddle_inference_api_impl, word2vec) {
+  ConfigImpl config = GetConfig();
+  std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);
 
   framework::LoDTensor first_word, second_word, third_word, fourth_word;
   framework::LoD lod{{0, 1}};
@@ -60,24 +63,91 @@ TEST(paddle_inference_api_impl, word2vec) {
   SetupLoDTensor(&third_word, lod, static_cast<int64_t>(0), dict_size - 1);
   SetupLoDTensor(&fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);
 
-  std::vector<PaddleTensor> cpu_feeds;
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&first_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&second_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&third_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
+  std::vector<PaddleTensor> paddle_tensor_feeds;
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&first_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&second_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&third_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
+
+  std::vector<PaddleTensor> outputs;
+  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
+  ASSERT_EQ(outputs.size(), 1UL);
+  size_t len = outputs[0].data.length;
+  float* data = static_cast<float*>(outputs[0].data.data);
+  for (int j = 0; j < len / sizeof(float); ++j) {
+    ASSERT_LT(data[j], 1.0);
+    ASSERT_GT(data[j], -1.0);
+  }
+
+  std::vector<framework::LoDTensor*> cpu_feeds;
+  cpu_feeds.push_back(&first_word);
+  cpu_feeds.push_back(&second_word);
+  cpu_feeds.push_back(&third_word);
+  cpu_feeds.push_back(&fourth_word);
+
+  framework::LoDTensor output1;
+  std::vector<framework::LoDTensor*> cpu_fetchs1;
+  cpu_fetchs1.push_back(&output1);
+
+  TestInference<platform::CPUPlace>(config.model_dir,
+                                    cpu_feeds,
+                                    cpu_fetchs1);
+
+  float* lod_data = output1.data<float>();
+  for (size_t i = 0; i < output1.numel(); ++i) {
+    EXPECT_LT(lod_data[i] - data[i], 1e-3);
+    EXPECT_GT(lod_data[i] - data[i], -1e-3);
+  }
+
+  free(outputs[0].data.data);
+}
+
+TEST(paddle_inference_api_impl, image_classification) {
+  int batch_size = 2;
+  bool use_mkldnn = false;
+  bool repeat = false;
+  ConfigImpl config = GetConfig();
+  config.model_dir =
+      FLAGS_dirname + "image_classification_resnet.inference.model";
+
+  const bool is_combined = false;
+  std::vector<std::vector<int64_t>> feed_target_shapes =
+      GetFeedTargetShapes(config.model_dir, is_combined);
+
+  framework::LoDTensor input;
+  // Use normalized image pixels as input data,
+  // which should be in the range [0.0, 1.0].
+  feed_target_shapes[0][0] = batch_size;
+  framework::DDim input_dims = framework::make_ddim(feed_target_shapes[0]);
+  SetupTensor<float>(
+      &input, input_dims, static_cast<float>(0), static_cast<float>(1));
+  std::vector<framework::LoDTensor*> cpu_feeds;
+  cpu_feeds.push_back(&input);
+
+  framework::LoDTensor output1;
+  std::vector<framework::LoDTensor*> cpu_fetchs1;
+  cpu_fetchs1.push_back(&output1);
+
+  TestInference<platform::CPUPlace, false, true>(config.model_dir,
+                                                 cpu_feeds,
+                                                 cpu_fetchs1,
+                                                 repeat,
+                                                 is_combined,
+                                                 use_mkldnn);
+
+  std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);
+  std::vector<PaddleTensor> paddle_tensor_feeds;
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input));
 
   std::vector<PaddleTensor> outputs;
-  ASSERT_TRUE(predictor->Run(cpu_feeds, &outputs));
+  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
   ASSERT_EQ(outputs.size(), 1UL);
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    size_t len = outputs[i].data.length;
-    float* data = static_cast<float*>(outputs[i].data.data);
-    for (size_t j = 0; j < len / sizeof(float); ++j) {
-      ASSERT_LT(data[j], 1.0);
-      ASSERT_GT(data[j], -1.0);
-    }
-    free(outputs[i].data.data);
+  size_t len = outputs[0].data.length;
+  float* data = static_cast<float*>(outputs[0].data.data);
+  float* lod_data = output1.data<float>();
+  for (size_t j = 0; j < len / sizeof(float); ++j) {
+    EXPECT_LT(lod_data[j] - data[j], 1e-10);
+    EXPECT_GT(lod_data[j] - data[j], -1e-10);
   }
+  free(data);
 }
 
 }  // namespace paddle
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index e3d2e5377eac49003b0082c39c9dd0460e2acd92..f87d5521492418d2daf5b7fba1500c4bb31e10f5 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -469,6 +469,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
  protected:
   DDim GetDim(const std::string& name) const override {
     Variable* var = scope_.FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(var);
     if (var->IsType<LoDTensor>()) {
       return var->Get<LoDTensor>().dims();
     } else if (var->IsType<SelectedRows>()) {
diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b14b559e31dd422f8ebe4002988a9746dfdf28a2
--- /dev/null
+++ b/paddle/fluid/operators/random_crop_op.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/fluid/operators/random_crop_op.h"
+
+namespace paddle {
+namespace operators {
+
+class RandomCropOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
+class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "A batch of instances to random crop.");
+    AddInput("Seed", "The random seed.");
+    AddOutput("Out", "The cropped instance batch.");
+    AddOutput("SeedOut", "The random seed after random cropping.")
+        .AsDispensable();
+    AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
+    AddComment(R"DOC(
+      This operator takes a batch of instances and performs random cropping
+      on each instance. Cropping positions differ across instances and are
+      determined by a uniform random generator. All cropped instances have
+      the same shape, which is determined by the operator's attribute 'shape'.
+    )DOC");
+  }
+};
+
+class RandomCropOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override {
+    auto seed_dim = ctx->GetInputDim("Seed");
+    PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
+    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
+    auto out_dim = framework::vectorize2int(x_dim);
+    for (size_t i = 1; i <= shape.size(); ++i) {
+      size_t x_i = x_dim.size() - i;
+      size_t shape_i = shape.size() - i;
+      PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
+      out_dim[x_i] = shape[shape_i];
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
+    ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace f = paddle::framework;
+REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
+                  ops::RandomCropOpInferShape, f::EmptyGradOpMaker);
+
+template <typename T>
+using Kernel = ops::RandomCropKernel<paddle::platform::CPUDeviceContext, T>;
+REGISTER_OP_CPU_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
+                       Kernel<uint8_t>, Kernel<int16_t>);
diff --git a/paddle/fluid/operators/random_crop_op.cu b/paddle/fluid/operators/random_crop_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6fc9bedc55b4d349ddf3d109c7f9049113235f0c
--- /dev/null
+++ b/paddle/fluid/operators/random_crop_op.cu
@@ -0,0 +1,21 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/random_crop_op.h"
+
+namespace ops = paddle::operators;
+template <typename T>
+using Kernel = ops::RandomCropKernel<paddle::platform::CUDADeviceContext, T>;
+REGISTER_OP_CUDA_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
+                        Kernel<uint8_t>, Kernel<int16_t>);
diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f3261cbdc986b0cc724315c1eb92b8b84e18c742
--- /dev/null
+++ b/paddle/fluid/operators/random_crop_op.h
@@ -0,0 +1,181 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/for_range.h"
+#ifdef PADDLE_WITH_CUDA
+#include <thrust/random.h>
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext>
+struct Random;
+
+template <>
+struct Random<platform::CPUDeviceContext> {
+  using Engine = std::minstd_rand;
+
+  template <typename T>
+  using UniformIntDist = std::uniform_int_distribution<T>;
+};
+
+#ifdef PADDLE_WITH_CUDA
+template <>
+struct Random<platform::CUDADeviceContext> {
+  using Engine = thrust::minstd_rand;
+
+  template <typename T>
+  using UniformIntDist = thrust::uniform_int_distribution<T>;
+};
+#endif
+
+template <typename T>
+HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
+                                     const size_t* out_dims, int i, int rank,
+                                     size_t prod_x_remain,
+                                     size_t prod_out_remain,
+                                     const size_t* offsets) {
+  size_t x_dim_i = x_dims[i];
+  size_t out_dim_i = out_dims[i];
+  size_t x_stride = prod_x_remain / x_dim_i;
+  size_t out_stride = prod_out_remain / out_dim_i;
+  size_t offset_i = offsets[i];
+
+  if (i == rank - 1) {
+    PADDLE_ASSERT(x_stride == 1 && out_stride == 1);
+    x += offset_i;
+    for (size_t j = 0; j < out_dim_i; ++j) {
+      *out++ = *x++;
+    }
+  } else {
+    x += offset_i * x_stride;
+    for (size_t j = 0; j < out_dim_i; ++j) {
+      StridedMemcpy(x, x_dims, out, out_dims, i + 1, rank, x_stride,
+                    out_stride, offsets);
+      x += x_stride;
+      out += out_stride;
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+struct RandomCropFunctor {
+  const T* x_;
+  T* out_;
+  size_t x_dims_[9];
+  size_t out_dims_[9];
+  int num_batchsize_dims_;
+  int rank_;
+  int64_t seed_;
+
+  size_t prod_batchsize_dims_;
+  size_t prod_x_ins_dims_;
+  size_t prod_out_ins_dims_;
+
+  RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims,
+                    const framework::DDim& out_dims, int num_batchsize_dims,
+                    int64_t seed)
+      : x_(x),
+        out_(out),
+        num_batchsize_dims_(num_batchsize_dims),
+        rank_(x_dims.size()),
+        seed_(seed) {
+    PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size());
+    PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_);
+    prod_batchsize_dims_ = 1;
+    prod_x_ins_dims_ = 1;
+    prod_out_ins_dims_ = 1;
+    for (size_t i = 0; i < static_cast<size_t>(rank_); ++i) {
+      size_t x_dim_i = x_dims[i];
+      size_t out_dim_i = out_dims[i];
+      x_dims_[i] = x_dim_i;
+      out_dims_[i] = out_dim_i;
+      if (i < static_cast<size_t>(num_batchsize_dims_)) {
+        PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i);
+        prod_batchsize_dims_ *=
 x_dim_i;
+      } else {
+        prod_x_ins_dims_ *= x_dim_i;
+        prod_out_ins_dims_ *= out_dim_i;
+      }
+    }
+  }
+
+  HOSTDEVICE void operator()(size_t ins_idx) {
+    typename Random<DeviceContext>::Engine engine(seed_);
+    engine.discard(ins_idx * (rank_ - num_batchsize_dims_));
+    size_t offsets[9];
+    for (int i = num_batchsize_dims_; i < rank_; ++i) {
+      typename Random<DeviceContext>::template UniformIntDist<size_t> dist(
+          0, x_dims_[i] - out_dims_[i]);
+      offsets[i - num_batchsize_dims_] = dist(engine);
+    }
+
+    const T* x = x_ + ins_idx * prod_x_ins_dims_;
+    T* out = out_ + ins_idx * prod_out_ins_dims_;
+
+    StridedMemcpy(x, x_dims_ + num_batchsize_dims_, out,
+                  out_dims_ + num_batchsize_dims_, 0,
+                  rank_ - num_batchsize_dims_, prod_x_ins_dims_,
+                  prod_out_ins_dims_, offsets);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class RandomCropKernel : public framework::OpKernel<T> {
+ public:
+  virtual void Compute(const framework::ExecutionContext& ctx) const {
+    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
+    int64_t seed = 0;
+    if (platform::is_cpu_place(seed_tensor.place())) {
+      seed = *seed_tensor.data<int64_t>();
+    } else {
+      LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
+                      "your program";
+      framework::LoDTensor cpu_seed;
+      framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
+      seed = *cpu_seed.data<int64_t>();
+    }
+    auto shape = ctx.Attr<std::vector<int>>("shape");
+    auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
+    auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
+
+    int num_batchsize_dims = x.dims().size() - shape.size();
+    RandomCropFunctor<DeviceContext, T> functor(
+        x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
+        num_batchsize_dims, seed);
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(),
+        functor.prod_batchsize_dims_);
+
+    for_range(functor);
+
+    Random<platform::CPUDeviceContext>::Engine engine(seed);
+    engine.discard(functor.prod_batchsize_dims_ *
+                   (functor.rank_ - functor.num_batchsize_dims_));
+    *ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
+        platform::CPUPlace()) = engine();
+  }
+};
+
+// TODO(fengjiayi): Backward of random crop op
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index e18b7acdeed4856ffdfbcc583c7d5e8bd2d79fb9..f049dd6fd9ec7485bba17feb09e8b68550353bc3 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -82,6 +82,7 @@ __all__ = [
     'roi_pool',
     'dice_loss',
     'upsampling_bilinear2d',
+    'random_crop',
 ]
 
 
@@ -154,7 +155,8 @@ def fc(input,
     Examples:
         .. code-block:: python
 
-          data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
+          data = fluid.layers.data(
+              name="data", shape=[32, 32], dtype="float32")
           fc = fluid.layers.fc(input=data, size=1000, act="tanh")
     """
 
@@ -349,7 +351,8 @@ def dynamic_lstm(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
@@ -516,10 +519,12 @@ def dynamic_lstmp(input,
         cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         proj_activation(str): The activation for projection output.
-                              Choices = ["sigmoid", "tanh", "relu", "identity"],
+                              Choices = ["sigmoid", "tanh",
+                                         "relu", "identity"],
                               default "tanh".
         dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
         name(str|None): A name for this layer(optional). If set None, the layer
@@ -2174,7 +2179,8 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
             fluid.layers.reduce_mean(x)  # [0.4375]
             fluid.layers.reduce_mean(x, dim=0)  # [0.15, 0.25, 0.55, 0.8]
             fluid.layers.reduce_mean(x, dim=-1)  # [0.475, 0.4]
-            fluid.layers.reduce_mean(x, dim=1, keep_dim=True)  # [[0.475], [0.4]]
+            fluid.layers.reduce_mean(
+                x, dim=1, keep_dim=True)  # [[0.475], [0.4]]
 
             # x is a Tensor variable with shape [2, 2, 2] and elements as below:
             #      [[[1.0, 2.0], [3.0, 4.0]],
@@ -2393,7 +2399,8 @@ def split(input, num_or_sections, dim=-1, name=None):
             x0.shape  # [3, 3, 5]
             x1.shape  # [3, 3, 5]
             x2.shape  # [3, 3, 5]
-            x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
+            x0, x1, x2 = fluid.layers.split(
+                x, num_or_sections=[2, 3, 4], dim=1)
             x0.shape  # [3, 2, 5]
             x1.shape  # [3, 3, 5]
             x2.shape  # [3, 4, 5]
@@ -3305,7 +3312,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
             data = fluid.layers.data(name='data', shape=[128], dtype='float32')
             label = fluid.layers.data(name='label', shape=[1], dtype='int64')
             fc = fluid.layers.fc(input=data, size=100)
-            out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)
+            out = fluid.layers.softmax_with_cross_entropy(
+                logits=fc, label=label)
     """
     helper = LayerHelper('softmax_with_cross_entropy', **locals())
     softmax = helper.create_tmp_variable(dtype=logits.dtype)
@@ -3352,7 +3360,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
        .. code-block:: python
 
            data = fluid.layers.data(name='data', shape=[128], dtype='float32')
-           label = fluid.layers.data(name='label', shape=[100], dtype='float32')
+           label = fluid.layers.data(
+               name='label', shape=[100], dtype='float32')
            fc = fluid.layers.fc(input=data, size=100)
            out = fluid.layers.smooth_l1(x=fc, y=label)
     """
@@ -3674,7 +3683,8 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
     Examples:
         .. code-block:: python
 
-          data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32")
+          data = fluid.layers.data(
+              name="data", shape=[3, 112, 112], dtype="float32")
           lrn = fluid.layers.lrn(input=data)
     """
     helper = LayerHelper('lrn', **locals())
@@ -3929,10 +3939,10 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
     Bilinear interpolation is an extension of linear interpolation for
     interpolating functions of two variables (e.g. H-direction and
     W-direction in this layer) on a rectilinear 2D grid.
-    
+
     For details, please refer to Wikipedia:
     https://en.wikipedia.org/wiki/Bilinear_interpolation
-    
+
     Args:
         input (Variable): The input tensor of bilinear interpolation,
                           This is a 4-D tensor of the shape
@@ -3950,7 +3960,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
     Returns:
         out (Variable): The output is a 4-D tensor of the shape
                         (num_batches, channels, out_h, out_w).
-    
+
     Examples:
         .. code-block:: python
@@ -3983,3 +3993,32 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None):
         attrs={"out_h": out_h,
                "out_w": out_w})
     return out
+
+
+def random_crop(input, shape, seed=1):
+    helper = LayerHelper("random_crop", **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    if isinstance(seed, int):
+        seed_value = seed
+        seed = helper.create_tmp_variable(dtype="int64")
+        helper.append_op(
+            type="fill_constant",
+            inputs={},
+            outputs={"Out": seed},
+            attrs={
+                "dtype": seed.dtype,
+                "shape": [1],
+                "value": float(seed_value)
+            })
+    elif not isinstance(seed, Variable):
+        raise ValueError("'seed' must be a Variable or an int.")
+    seed_out = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="random_crop",
+        inputs={"X": input,
+                "Seed": seed},
+        outputs={"Out": out,
+                 "SeedOut": seed_out},
+        attrs={"shape": shape})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c708d0386da4028f1f3d177d0a3fd494c077c6e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+class TestRandomCropOp(OpTest):
+    def setUp(self):
+        to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] *
+                           5).astype("float32")
+        self.possible_res = [
+            np.array([[1, 2, 3], [5, 6, 7]]),
+            np.array([[2, 3, 4], [6, 7, 8]]),
+            np.array([[5, 6, 7], [9, 10, 11]]),
+            np.array([[6, 7, 8], [10, 11, 12]])
+        ]
+        self.op_type = "random_crop"
+        self.inputs = {'X': to_crop, 'Seed': np.array([10])}
+        self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])}
+        self.attrs = {'shape': [2, 3]}
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def verify_output(self, outs):
+        out = np.array(outs[1])
+        for ins in out[:]:
+            is_equal = [(ins == res).all() for res in self.possible_res]
+            self.assertIn(True, is_equal)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ce612ca2318ccb9b9f28d51cb93ce8e5e1d0680
Binary files /dev/null and b/tools/codestyle/docstring_checker.pyc differ
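
Note (not part of the patch): a minimal usage sketch of the new `fluid.layers.random_crop` layer added in `nn.py`. The `images` data layer, its dimensions, and the crop size are hypothetical; per `RandomCropOpInferShape`, `shape` covers the trailing dimensions of each instance, each cropped dimension must be no larger than the corresponding input dimension, and the leading (batch) dimensions pass through unchanged.

```python
import paddle.fluid as fluid

# Hypothetical pipeline: 3x224x224 images, randomly cropped to 3x192x192.
# `seed` may be a plain int (a fill_constant op is inserted for it) or an
# int64 Variable; SeedOut carries the advanced seed to the next iteration.
images = fluid.layers.data(name="images", shape=[3, 224, 224], dtype="float32")
cropped = fluid.layers.random_crop(images, shape=[3, 192, 192], seed=1)
```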
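For reference, a NumPy sketch of the semantics implemented by `RandomCropFunctor` and `StridedMemcpy` above: every instance in the batch gets an independent uniform random offset in each cropped (trailing) dimension, and the corresponding sub-block is copied out. This is an illustrative reimplementation under assumed semantics, not code from the patch; in particular, the RNG here will not reproduce the operator's exact outputs.

```python
import numpy as np

def random_crop_reference(x, shape, rng):
    """Crop each instance of x to `shape` at a per-instance random offset."""
    n_crop = len(shape)
    batch_dims = x.shape[:x.ndim - n_crop]
    out = np.empty(batch_dims + tuple(shape), dtype=x.dtype)
    # Flatten all batch dimensions into one, as the functor does.
    flat_x = x.reshape((-1,) + x.shape[x.ndim - n_crop:])
    flat_out = out.reshape((-1,) + tuple(shape))
    for i in range(flat_x.shape[0]):
        # One uniform offset per cropped dimension.
        offsets = [rng.randint(0, d - s + 1)
                   for d, s in zip(flat_x.shape[1:], shape)]
        slices = tuple(slice(o, o + s) for o, s in zip(offsets, shape))
        flat_out[i] = flat_x[i][slices]
    return out

# Example matching test_random_crop_op.py: five 3x4 instances cropped to 2x3.
x = np.arange(60, dtype=np.float32).reshape(5, 3, 4)
out = random_crop_reference(x, [2, 3], np.random.RandomState(10))
assert out.shape == (5, 2, 3)
```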