diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 26590171bbeaa385ac09b04e5faf483924176598..7fdd1c6b76aebcea757540e7312a679b8c08402a 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include // for exp #include // for memcpy +#include #include #include #include "gflags/gflags.h" diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 728197377df04df8c993a48bc282431473fe9959..d4f1da69f0a6aee0d915e053ee868f69aecd9348 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -16,7 +16,9 @@ function(reader_library TARGET_NAME) endfunction() cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) +cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool) reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader) +reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e182521f9ab868c4de78d465f92008a84f5403e2 --- /dev/null +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/ctr_reader.h" + +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class CreateCTRReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + if (out->Get() != nullptr) return; + + const std::string& queue_name = Input("blocking_queue"); + auto* queue_holder_var = scope.FindVar(queue_name); + PADDLE_ENFORCE_NOT_NULL( + queue_holder_var, + "No LoDTensorBlockingQueueHolder variable with name %s found", + queue_name); + auto* queue_holder = + queue_holder_var->template GetMutable(); + + out->Reset(std::make_shared(queue_holder->GetQueue())); + } +}; + +class CreateCTRReaderOpMaker : public FileReaderMakerBase { + protected: + void Apply() override { + AddInput("blocking_queue", + "Name of the `LoDTensorBlockingQueueHolder` variable"); + + AddComment(R"DOC( + Create CTRReader to support read ctr data with cpp. + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = ::paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(create_ctr_reader, reader::CreateCTRReaderOp, + reader::CreateCTRReaderOpMaker); diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc new file mode 100644 index 0000000000000000000000000000000000000000..bcf49fc967cab6a275e154d1b6f18034734f9200 --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/ctr_reader.h" + +namespace paddle { +namespace operators { +namespace reader {} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..c3cf78e5f43aa99df42de2981a0a5ec5d96e8133 --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { +namespace reader { + +class CTRReader : public framework::FileReader { + public: + explicit CTRReader(const std::shared_ptr& queue) + : framework::FileReader() { + PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); + queue_ = queue; + } + + void ReadNext(std::vector* out) override { + bool success; + *out = queue_->Pop(&success); + if (!success) out->clear(); + } + + ~CTRReader() { queue_->Close(); } + + void Shutdown() override { queue_->Close(); } + + void Start() override { queue_->ReOpen(); } + + private: + std::shared_ptr queue_; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle