提交 eb053a62 编写于 作者: X xulei2020

add filterOp code

上级 27a88a6b
......@@ -116,20 +116,14 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
continue;
}
// Thread local variables to avoid lock. When in_columns_ is empty and workers will write
// the name of the first column into input_columns (thread local) instead of in_columns_ (thread global).
std::vector<std::string> input_columns = in_columns_;
// Indices of the columns to process.
std::vector<size_t> to_process_indices;
RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns));
RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_));
// if the databuffer was all filtered, it is marked as kFilterEmpty.
// if the databuffer was partially filtered, it is marked as kFilterPartial.
// if the databuffer was not filtered, it is marked as kFilterFull.
int32_t num_rows = in_buffer->NumRows();
std::unique_ptr<TensorQTable> new_tensor_table;
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table));
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table));
if (new_tensor_table->empty()) {
RETURN_IF_NOT_OK(
......@@ -147,17 +141,22 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
return Status::OK();
}
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
std::unique_ptr<TensorQTable> *out) {
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out) {
*out = std::make_unique<TensorQTable>();
int32_t num_rows = in_buffer->NumRows();
for (int32_t i = 0; i < num_rows; i++) {
TensorRow to_process;
TensorRow cur_row;
RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
(void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process),
[&cur_row](const size_t &it) -> std::shared_ptr<Tensor> { return cur_row[it]; });
if (in_columns_.empty() == true) {
MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
to_process = cur_row;
} else {
std::unordered_map<std::string, int32_t> col_map = in_buffer->column_name_map();
(void)std::transform(
in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process),
[&cur_row, &col_map](const auto &it) -> std::shared_ptr<Tensor> { return cur_row[col_map[it]]; });
}
bool predicate = true;
RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
if (predicate) {
......@@ -202,9 +201,8 @@ Status FilterOp::Collector() {
return Status::OK();
}
// initialize some internal data structure used by WorkerEntry().
Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
std::vector<std::string> *input_columns) {
// Private function for checking the column legality.
Status FilterOp::CheckColumns(const DataBuffer *in_buf, std::vector<std::string> *input_columns) {
int32_t num_rows = in_buf->NumRows();
int32_t num_cols = in_buf->NumCols();
if (num_rows == 0 || num_cols == 0) {
......@@ -213,24 +211,6 @@ Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *
std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map();
// Check if there is invalid column name in the inColumns.
RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns));
if (input_columns->empty()) {
MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
// sort the input colunms by column index.
std::vector<std::pair<std::string, int32_t>> sort_vec(col_name_id_map.begin(), col_name_id_map.end());
std::sort(sort_vec.begin(), sort_vec.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
(void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns),
[](const auto &it) -> std::string { return it.first; });
}
// initialize to_process_indices.
(void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices),
[&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; });
return Status::OK();
}
......
......@@ -141,8 +141,7 @@ class FilterOp : public ParallelOp {
// @param to_proess_indices Indices of columns to be processed.
// @param out data buffer that are filtered by predicate.
// @return Status The error code return.
Status WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
std::unique_ptr<TensorQTable> *out);
Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);
// Collector databuffer.
// @return Status The error code return.
......@@ -166,13 +165,12 @@ class FilterOp : public ParallelOp {
Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
std::vector<std::string> *input_columns);
// Private function that initialize some internal data structure used by WorkerEntry().
// Private function for checking the column legality
// @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
// and is not shared with other threads.
// @param[out] to_process_indices Indices of columns that will feed to predicate.
// @param input_columns The vector of input column names used in the current thread.
Status WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
std::vector<std::string> *input_columns);
Status CheckColumns(const DataBuffer *in_buf, std::vector<std::string> *input_columns);
};
} // namespace dataset
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册