diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 951514710b29d1d6fb3165acd92b65d997f43772..3f576a45169c9c7e4581d304efc7cf0bca1b310a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -118,6 +118,8 @@ paddle.fluid.layers.reduce_mean (ArgSpec(args=['input', 'dim', 'keep_dim', 'name paddle.fluid.layers.reduce_max (ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)), ('document', '66a622db727551761ce4eb73eaa7f6a4')) paddle.fluid.layers.reduce_min (ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)), ('document', 'd50ac552b5d131468ed466d08bb2d38c')) paddle.fluid.layers.reduce_prod (ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)), ('document', 'fcd8301a0ce15f219c7a4bcd0c1e8eca')) +paddle.fluid.layers.reduce_all (ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)), ('document', '646ca4d4a2cc16084f59de44b6927eca')) +paddle.fluid.layers.reduce_any (ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)), ('document', 'f36661060aeeaf6c6b1331e41b3726fa')) paddle.fluid.layers.sequence_first_step (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '2b290d3d77882bfe9bb8d331cac8cdd3')) paddle.fluid.layers.sequence_last_step (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'c16a892f44f7fe71bfa5afc32d3f34ce')) paddle.fluid.layers.sequence_slice (ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'fdcea0e8b5bc7d8d4b1b072c521014e6')) @@ -204,6 +206,7 @@ paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'sha paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'a418e3ccb5e2ac21bd60f5cc221d5860')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '01dbb91e7c74cb11336cd531013de51a')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '17db0f814eb7bb5a3fac1ca6e60e16d8')) +paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'ee1386c42ecc8f424fe3fb21862fefc2')) paddle.fluid.layers.logical_and (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'cdcf20c494c92060d10feb9374532f42')) paddle.fluid.layers.logical_or (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '0eae3f726a4afe590757552fa3ced012')) paddle.fluid.layers.logical_xor (ArgSpec(args=['x', 'y', 'out', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'b0daaa3fa4a0aa62f9b58c43d959eb25')) @@ -272,6 +275,7 @@ paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, de paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '2e53e83127dbfd86e7098bdfe9a549e8')) paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '0a437011c3906079fd8947ed3e52d292')) paddle.fluid.layers.range (ArgSpec(args=['start', 'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '2ec937ede953ded2fdff2675883900bb')) +paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '495e21e9a848c2d075a102802fc67756')) paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.While.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.Switch.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 28a37f331c100695f0ffec7288db84f4493d68a0..12ce99c8788625e2aae6e07abdea565bb2c2ebb9 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -31,10 +31,10 @@ namespace paddle { namespace framework { namespace ir { namespace { -void SortHelper( - const std::map> &adj_list, - ir::Node *node, std::unordered_set *visited, - std::vector *ret) { +void SortHelper(const std::map, + ir::NodeComp> &adj_list, + ir::Node *node, std::unordered_set *visited, + std::vector *ret) { visited->insert(node); for (auto adj : adj_list.at(node)) { @@ -50,7 +50,8 @@ void SortHelper( bool HasCircleHelper( ir::Node *node, - const std::map> &adj_list, + const std::map, ir::NodeComp> + &adj_list, std::unordered_set *visited, std::unordered_set *in_trace, std::vector> *circles) { @@ -84,7 +85,8 @@ bool HasCircleHelper( } bool HasCircleInternal( - const std::map> &adj_list, + const std::map, ir::NodeComp> + &adj_list, std::vector> *circles) { std::unordered_set visited; std::unordered_set in_trace; @@ -107,8 +109,8 @@ bool FindCircleSubGraph(const Graph &graph, } std::vector TopologySortOperations(const Graph &graph) { - std::map> adj_list = - BuildOperationAdjList(graph); + std::map, ir::NodeComp> + adj_list = BuildOperationAdjList(graph); PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr)); std::unordered_set visited; std::vector ret; @@ -117,34 +119,30 @@ std::vector TopologySortOperations(const Graph &graph) { SortHelper(adj_list, adj.first, &visited, &ret); } } + return ret; } // Build operator inlink edge table. -std::map> BuildOperationAdjList( - const Graph &graph) { - std::map> adj_list; +std::map, ir::NodeComp> +BuildOperationAdjList(const Graph &graph) { + std::map, ir::NodeComp> + adj_list; for (auto &n : graph.Nodes()) { if (!n->IsOp()) continue; if (adj_list.find(n) == adj_list.end()) { - adj_list[n] = std::unordered_set(); + adj_list[n] = std::set(); } - std::vector nodes; for (auto &var : n->inputs) { for (auto &adj_n : var->inputs) { PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) << " -> " << n->Name() << reinterpret_cast(n) << " via " << var->Name() << reinterpret_cast(var); - nodes.push_back(adj_n); + adj_list[n].insert(adj_n); } } - std::sort(nodes.begin(), nodes.end(), [](ir::Node *node1, ir::Node *node2) { - return node1->id() > node2->id(); - }); - adj_list[n].insert(std::make_move_iterator(nodes.begin()), - std::make_move_iterator(nodes.end())); } return adj_list; } diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 214de9ec7d85aee6021b18866295777e317aa79d..849a9c3be6904f3f9c3669d8fc9d750154863031 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -25,6 +26,13 @@ namespace paddle { namespace framework { namespace ir { +// Compare nodes via node id. +struct NodeComp { + bool operator()(ir::Node *const &node1, ir::Node *const &node2) const { + return node1->id() < node2->id(); + } +}; + // Test if the graph contains circle. bool HasCircle(const Graph &graph); @@ -57,8 +65,8 @@ std::vector TopologyVarientSort(const Graph &graph, SortKind sort_kind); void CleanIndividualNodes(Graph *graph); // Build an adjacency list of operations for the `graph`. -std::map> BuildOperationAdjList( - const Graph &graph); +std::map, ir::NodeComp> +BuildOperationAdjList(const Graph &graph); template std::vector FilterByNodeWrapper(const Graph &graph) { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e6f5b15af8cd440a9304235acfe62787c5f1b134..1ea93b7638a85e67bcc85a0c0e130d636938d6c5 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -241,6 +241,7 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs, outputs_ = outputs; attrs_ = attrs; need_update_ = true; + block_ = nullptr; } OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 6942604b0723f8665f0e8b058d48a5356a1a01f4..0155609a029664da2c3d4c90a152ec556927c32d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -259,6 +259,9 @@ bool AnalysisPredictor::SetFeed(const std::vector &inputs, return false; } + PADDLE_ENFORCE_NOT_NULL(input_ptr); + PADDLE_ENFORCE_NOT_NULL(inputs[i].data.data()); + if (platform::is_cpu_place(place_)) { // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. std::memcpy(static_cast(input_ptr), inputs[i].data.data(), diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc index 7d57b6ec74468dbdb0519f85140629a0ac01c18d..fc2d7b48c2a1f89232dcb96d1899667230e2ddda 100644 --- a/paddle/fluid/inference/api/api.cc +++ b/paddle/fluid/inference/api/api.cc @@ -54,6 +54,7 @@ PaddleBuf &PaddleBuf::operator=(const PaddleBuf &other) { memory_owned_ = other.memory_owned_; } else { Resize(other.length()); + PADDLE_ENFORCE(!(other.length() > 0 && other.data() == nullptr)); memcpy(data_, other.data(), other.length()); length_ = other.length(); memory_owned_ = true; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 54f40563c3662af24e794422be4d3262d86c76a7..56996c5cff88f5b4a9094291a09996f8b8d70a23 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -169,6 +169,7 @@ std::unique_ptr NativePaddlePredictor::Clone() { std::unique_ptr cls(new NativePaddlePredictor(config_)); // Hot fix the bug that result diff in multi-thread. // TODO(Superjomn) re-implement a real clone here. + PADDLE_ENFORCE_NOT_NULL(dynamic_cast(cls.get())); if (!dynamic_cast(cls.get())->Init(nullptr)) { LOG(ERROR) << "fail to call Init"; return nullptr; @@ -210,6 +211,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, return false; } + PADDLE_ENFORCE_NOT_NULL(input_ptr); + PADDLE_ENFORCE_NOT_NULL(inputs[i].data.data()); if (platform::is_cpu_place(place_)) { // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. std::memcpy(static_cast(input_ptr), inputs[i].data.data(), @@ -316,6 +319,8 @@ std::unique_ptr CreatePaddlePredictor< } std::unique_ptr predictor(new NativePaddlePredictor(config)); + PADDLE_ENFORCE_NOT_NULL( + dynamic_cast(predictor.get())); if (!dynamic_cast(predictor.get())->Init(nullptr)) { return nullptr; } diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc index 9f23b9f037bcaeb758312d011067ae29c82e73cd..5ee848c3cfa2117b2adeab5e563c5d07ce1d76ca 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc @@ -47,6 +47,7 @@ struct DataRecord { num_lines++; std::vector data; split(line, '\t', &data); + PADDLE_ENFORCE(data.size() >= 4); // load title1 data std::vector title1_data; split_to_int64(data[0], ' ', &title1_data); diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc index bd4f1b61973fb0de06dcc288e329c94756d5ed47..a23297f29cf65d891f530850ffd184aa58e10886 100644 --- a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc @@ -214,28 +214,23 @@ TEST(Analyzer_Transformer, fuse_statis) { } // Compare result of NativeConfig and AnalysisConfig -// void compare(bool use_mkldnn = false) { -// AnalysisConfig cfg; -// SetConfig(&cfg); -// if (use_mkldnn) { -// cfg.EnableMKLDNN(); -// } -// -// std::vector> input_slots_all; -// SetInput(&input_slots_all); -// CompareNativeAndAnalysis( -// reinterpret_cast(&cfg), -// input_slots_all); -// } - -// TODO(yihuaxu): -// Disable compare and compare_mkldnn temporary, see -// https://github.com/paddlePaddle/Paddle/issues/16316 for details. -// TEST(Analyzer_Transformer, compare) { compare(); } -// #ifdef PADDLE_WITH_MKLDNN -// TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); -// } -// #endif +void compare(bool use_mkldnn = false) { + AnalysisConfig cfg; + SetConfig(&cfg); + if (use_mkldnn) { + cfg.EnableMKLDNN(); + } + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareNativeAndAnalysis( + reinterpret_cast(&cfg), input_slots_all); +} + +TEST(Analyzer_Transformer, compare) { compare(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); } +#endif } // namespace inference } // namespace paddle diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index 21a25ce7d5e2bad172cf50cee6138ef4b44b07c1..63eaa676a43fc784dce2437ca15bc85e2295dbb7 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -29,6 +29,8 @@ pool3d prelu quantize rank_loss +reduce_all +reduce_any reduce_max reduce_mean reduce_min diff --git a/paddle/fluid/operators/detection/gpc.cc b/paddle/fluid/operators/detection/gpc.cc index 7c0823c0487d39eece5be08322e7d182b931ba3c..f46aaf7d0a7b2d48f18ba6cccb555bbb691ad353 100644 --- a/paddle/fluid/operators/detection/gpc.cc +++ b/paddle/fluid/operators/detection/gpc.cc @@ -24,6 +24,7 @@ **/ #include "paddle/fluid/operators/detection/gpc.h" +#include "paddle/fluid/platform/enforce.h" namespace gpc { @@ -689,6 +690,7 @@ static bbox *create_contour_bboxes(gpc_polygon *p) { gpc_malloc(box, p->num_contours * sizeof(bbox), const_cast("Bounding box creation")); + PADDLE_ENFORCE_NOT_NULL(box); /* Construct contour bounding boxes */ for (c = 0; c < p->num_contours; c++) { @@ -852,6 +854,7 @@ void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) { /* Create an extended hole array */ gpc_malloc(extended_hole, (p->num_contours + 1) * sizeof(int), const_cast("contour hole addition")); + PADDLE_ENFORCE_NOT_NULL(extended_hole); /* Create an extended contour array */ gpc_malloc(extended_contour, @@ -969,6 +972,7 @@ void gpc_polygon_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip, /* Build scanbeam table from scanbeam tree */ gpc_malloc(sbt, sbt_entries * sizeof(double), const_cast("sbt creation")); + PADDLE_ENFORCE_NOT_NULL(sbt); build_sbt(&scanbeam, sbt, sbtree); scanbeam = 0; free_sbtree(&sbtree); @@ -1604,6 +1608,7 @@ void gpc_tristrip_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip, /* Build scanbeam table from scanbeam tree */ gpc_malloc(sbt, sbt_entries * sizeof(double), const_cast("sbt creation")); + PADDLE_ENFORCE_NOT_NULL(sbt); build_sbt(&scanbeam, sbt, sbtree); scanbeam = 0; free_sbtree(&sbtree); diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4aeb062d8dfae31a72b8ebccb3d377276662da6 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/linspace_op.h" + +namespace paddle { +namespace operators { + +class LinspaceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Start"), + "Input(Start) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Stop"), + "Input(Stop) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Num"), + "Input(Num) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(OUt) of LinspaceOp should not be null."); + + auto s_dims = ctx->GetInputDim("Start"); + PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1), + "The shape of Input(Start) should be [1]."); + + auto e_dims = ctx->GetInputDim("Stop"); + PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1), + "The shape of Input(Stop) should be [1]."); + + auto step_dims = ctx->GetInputDim("Num"); + PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1), + "The shape of Input(Num) should be [1]."); + + ctx->SetOutputDim("Out", {-1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + ctx.Input("Start")->type(), ctx.device_context(), + layout_, library_); + } +}; + +class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Start", + "First entry in the sequence. It is a tensor of shape [1], should " + "be of type float32 or float64."); + AddInput("Stop", + "Last entry in the sequence. It is a tensor of shape [1], should " + "be of type float32 or float64."); + AddInput("Num", + "Number of entry in the sequence. It is a tensor of shape [1], " + "should be of type int32."); + AddOutput("Out", "A sequence of numbers."); + AddComment(R"DOC( + Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker); +REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel, + ops::CPULinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..90bd17cda0e0d1f78810233537bb502f9115fbd0 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.cu @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/linspace_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { + CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +} + +template +__global__ void LinspaceSpecialKernel(T start, T* out) { + out[0] = start; +} + +template +class CUDALinspaceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* start_t = context.Input("Start"); + auto* stop_t = context.Input("Stop"); + auto* num_t = context.Input("Num"); + auto* out = context.Output("Out"); + + framework::Tensor n; + framework::TensorCopy(*start_t, platform::CPUPlace(), &n); + T start = n.data()[0]; + framework::TensorCopy(*stop_t, platform::CPUPlace(), &n); + T stop = n.data()[0]; + framework::TensorCopy(*num_t, platform::CPUPlace(), &n); + int32_t num = n.data()[0]; + + PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); + + out->Resize(framework::make_ddim({num})); + T* out_data = out->mutable_data(context.GetPlace()); + + T step = 0; + if (num != 1) { + step = (stop - start) / (num - 1); + } + + auto stream = context.cuda_device_context().stream(); + int block = 512; + int grid = (num + block - 1) / block; + LinspaceKernel<<>>(start, step, num, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel, + ops::CUDALinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b1fcac73b0ad249aa19859bde770a8554cdb7408 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class CPULinspaceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + T start = context.Input("Start")->data()[0]; + T stop = context.Input("Stop")->data()[0]; + int32_t num = context.Input("Num")->data()[0]; + auto* out = context.Output("Out"); + PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); + + out->Resize(framework::make_ddim({num})); + + T* out_data = out->mutable_data(context.GetPlace()); + + if (num > 1) { + T step = (stop - start) / (num - 1); + T value = start; + for (int i = 0; i < num; ++i) { + out_data[i] = value; + value += step; + } + } else { + out_data[0] = start; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b087fbbb94c7ba2f7449f6bda56010dee1c38ea6 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h" + +REGISTER_REDUCE_OP(reduce_all); +REGISTER_OP_CPU_KERNEL(reduce_all, + ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd94ba263d957d0d65506ecd802bf43add6e2fb4 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_all, + ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.h b/paddle/fluid/operators/reduce_ops/reduce_all_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ba159dd703c8904784546eda262bf7be77967d48 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.h @@ -0,0 +1,29 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +namespace paddle { +namespace operators { + +struct AllFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->all(dim); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d865dcb3c935b76b8da25d723a5f780fb4de255b --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. Any Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" + +REGISTER_REDUCE_OP(reduce_any); +REGISTER_OP_CPU_KERNEL(reduce_any, + ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..66f0c9997ea1e27cf172a6839a68d2eb23395c4d --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2018 PaddlePaddle Authors. Any Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_any, + ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.h b/paddle/fluid/operators/reduce_ops/reduce_any_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b36bad9cada259932d2bd77c2426fbb46790de76 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.h @@ -0,0 +1,29 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +namespace paddle { +namespace operators { + +struct AnyFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->any(dim); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/squared_l2_distance_op.h b/paddle/fluid/operators/squared_l2_distance_op.h index e0133d33e6a840d2d06832393a064df978cb9cbc..12a8f05b5a603417ead8ebd250ff7951f928f4a1 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.h +++ b/paddle/fluid/operators/squared_l2_distance_op.h @@ -77,6 +77,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel { auto* x_g = context.Output(framework::GradVarName("X")); auto* y_g = context.Output(framework::GradVarName("Y")); + PADDLE_ENFORCE_NOT_NULL(x_g); + PADDLE_ENFORCE_NOT_NULL(y_g); + auto sub_result = EigenMatrix::From(*in0); auto out_grad = EigenMatrix::From(*in1); @@ -92,31 +95,28 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel { // propagate back to input auto& eigen_place = *context.template device_context().eigen_device(); - if (x_g) { - x_g->mutable_data(context.GetPlace()); - // eigen matrix - auto x_grad = - EigenMatrix::From(*x_g, framework::make_ddim({x_dims[0], cols})); - // dimensions are same with subResult - x_grad.device(eigen_place) = grad_mat; - } - if (y_g) { - y_g->mutable_data(context.GetPlace()); - - PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0], - "First dimension of gradient must be greater or " - "equal than first dimension of target."); - - if (sub_result.dimensions()[0] == y_dims[0]) { - auto y_grad = - EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols})); - y_grad.device(eigen_place) = -1 * grad_mat; - } else { - auto col_sum_res = -1 * (grad_mat.sum(Eigen::array({{0}}))); - auto y_grad = EigenVector::Flatten(*y_g); - y_grad.device(eigen_place) = col_sum_res; - } + x_g->mutable_data(context.GetPlace()); + // eigen matrix + auto x_grad = + EigenMatrix::From(*x_g, framework::make_ddim({x_dims[0], cols})); + // dimensions are same with subResult + x_grad.device(eigen_place) = grad_mat; + + y_g->mutable_data(context.GetPlace()); + + PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0], + "First dimension of gradient must be greater or " + "equal than first dimension of target."); + + if (sub_result.dimensions()[0] == y_dims[0]) { + auto y_grad = + EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols})); + y_grad.device(eigen_place) = -1 * grad_mat; + } else { + auto col_sum_res = -1 * (grad_mat.sum(Eigen::array({{0}}))); + auto y_grad = EigenVector::Flatten(*y_g); + y_grad.device(eigen_place) = col_sum_res; } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index eaa07477d85be0674dd654097849cf6d3f0ac442..93e46eef16fb177169db679a8437d9a33ed38e99 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -73,6 +73,8 @@ __all__ = [ 'reduce_max', 'reduce_min', 'reduce_prod', + 'reduce_all', + 'reduce_any', 'sequence_first_step', 'sequence_last_step', 'sequence_slice', @@ -159,6 +161,7 @@ __all__ = [ 'sum', 'slice', 'shape', + 'rank', 'logical_and', 'logical_or', 'logical_xor', @@ -4738,6 +4741,106 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): return out +def reduce_all(input, dim=None, keep_dim=False, name=None): + """ + Computes the ``logical and`` of tensor elements over the given dimension. + + Args: + input (Variable): The input variable which is a Tensor or LoDTensor. + dim (list|int|None): The dimension along which the logical and is computed. + If :attr:`None`, compute the logical and over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The reduced Tensor variable. + + Examples: + .. code-block:: python + + # x is a bool Tensor variable with following elements: + # [[True, False] + # [True, True]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_all(x) # False + fluid.layers.reduce_all(x, dim=0) # [True, False] + fluid.layers.reduce_all(x, dim=-1) # [False, True] + fluid.layers.reduce_all(x, dim=1, + keep_dim=True) # [[False], [True]] + + """ + helper = LayerHelper('reduce_all', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] + helper.append_op( + type='reduce_all', + inputs={'X': input}, + outputs={'Out': out}, + attrs={ + 'dim': dim if dim != None else [0], + 'keep_dim': keep_dim, + 'reduce_all': True if dim == None else False + }) + return out + + +def reduce_any(input, dim=None, keep_dim=False, name=None): + """ + Computes the ``logical or`` of tensor elements over the given dimension. + + Args: + input (Variable): The input variable which is a Tensor or LoDTensor. + dim (list|int|None): The dimension along which the logical or is computed. + If :attr:`None`, compute the logical or over all elements of + :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. + If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`. + keep_dim (bool): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The reduced Tensor variable. + + Examples: + .. code-block:: python + + # x is a bool Tensor variable with following elements: + # [[True, False] + # [False, False]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_any(x) # True + fluid.layers.reduce_any(x, dim=0) # [True, False] + fluid.layers.reduce_any(x, dim=-1) # [True, False] + fluid.layers.reduce_any(x, dim=1, + keep_dim=True) # [[True], [False]] + + """ + helper = LayerHelper('reduce_any', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + if dim is not None and not isinstance(dim, list): + dim = [dim] + helper.append_op( + type='reduce_any', + inputs={'X': input}, + outputs={'Out': out}, + attrs={ + 'dim': dim if dim != None else [0], + 'keep_dim': keep_dim, + 'reduce_all': True if dim == None else False + }) + return out + + def split(input, num_or_sections, dim=-1, name=None): """ Split the input tensor into multiple sub-tensors. @@ -9237,6 +9340,32 @@ def shape(input): return out +def rank(input): + """ + **Rank Layer** + + Returns the number of dimensions for a tensor, which is a 0-D int32 Tensor. + + Args: + input (Variable): The input variable. + + Returns: + Variable: The rank of the input variable. + + Examples: + .. code-block:: python + + input = layers.data( + name="input", shape=[3, 100, 100], dtype="float32") + rank = layers.rank(input) # 4 + """ + + ndims = len(input.shape) + out = assign(np.array(ndims, 'int32')) + + return out + + def _elementwise_op(helper): op_type = helper.layer_type x = helper.kwargs.get('x', None) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 80450119f44e93aae4b483983484ea18be5b2035..03ebd41fa00c69bfce66d325e32fc9aeb25a2486 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -24,26 +24,11 @@ from .layer_function_generator import templatedoc import numpy __all__ = [ - 'create_tensor', - 'create_parameter', - 'create_global_var', - 'cast', - 'tensor_array_to_tensor', - 'concat', - 'sums', - 'assign', - 'fill_constant_batch_size_like', - 'fill_constant', - 'argmin', - 'argmax', - 'argsort', - 'ones', - 'zeros', - 'reverse', - 'has_inf', - 'has_nan', - 'isfinite', - 'range', + 'create_tensor', 'create_parameter', 'create_global_var', 'cast', + 'tensor_array_to_tensor', 'concat', 'sums', 'assign', + 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', + 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', + 'range', 'linspace' ] @@ -826,3 +811,45 @@ def range(start, end, step, dtype): 'Step': step}, outputs={'Out': [out]}) return out + + +def linspace(start, stop, num, dtype): + """ + Return fixed number of evenly spaced values within a given interval. + + First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. + + Args: + start(float|Variable): First entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'. + stop(float|Variable): Last entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'. + num(int|Variable): Number of entry in the sequence. It is an int scalar, or a tensor of shape [1] with type int32. + dtype(string): 'float32'|'float64', the data type of the output tensor. + + Returns: + Variable: The tensor variable storing a 1-D tensor. + + Examples: + .. code-block:: python + + data = fluid.layers.linspace(0, 10, 5, 'float32') # [0.0, 2.5, 5.0, 7.5, 10.0] + data = fluid.layers.linspace(0, 10, 1, 'float32') # [0.0] + + """ + helper = LayerHelper("linspace", **locals()) + + if not isinstance(start, Variable): + start = fill_constant([1], dtype, start) + if not isinstance(stop, Variable): + stop = fill_constant([1], dtype, stop) + if not isinstance(num, Variable): + num = fill_constant([1], 'int32', num) + + out = helper.create_variable_for_type_inference(dtype=start.dtype) + + helper.append_op( + type='linspace', + inputs={'Start': start, + 'Stop': stop, + 'Num': num}, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38d0533a7ec820241c6a08f2180a7426984068f2..6630fb26aff9a8c570e65c34a753595da883bea1 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1925,6 +1925,13 @@ class TestBook(LayerTest): out = layers.flatten(x, axis=1, name="flatten") return (out) + def test_linspace(self): + program = Program() + with program_guard(program): + out = layers.linspace(20, 10, 5, 'float64') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py new file mode 100644 index 0000000000000000000000000000000000000000..eeecf178320327cc251f32bfe46c1622200339f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_linspace.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestLinspaceOpCommonCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32') + } + + self.outputs = {'Out': np.arange(0, 11).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLinspaceOpReverseCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([11]).astype('int32') + } + + self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLinspaceOpNumOneCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([1]).astype('int32') + } + + self.outputs = {'Out': np.array(10, dtype=dtype)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 8fc8125a773543eea768783155ad152c475535b5..65fc1453d8db13ad9c85746c3bf148f898e8f788 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -91,6 +91,78 @@ class TestProdOp(OpTest): self.check_grad(['X'], 'Out') +class TestAllOp(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.outputs = {'Out': self.inputs['X'].all()} + self.attrs = {'reduce_all': True} + + def test_check_output(self): + self.check_output() + + +class TestAllOpWithDim(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.attrs = {'dim': [1]} + self.outputs = {'Out': self.inputs['X'].all(axis=1)} + + def test_check_output(self): + self.check_output() + + +class TestAllOpWithKeepDim(OpTest): + def setUp(self): + self.op_type = "reduce_all" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.attrs = {'dim': [1], 'keep_dim': True} + self.outputs = { + 'Out': np.expand_dims( + self.inputs['X'].all(axis=1), axis=1) + } + + def test_check_output(self): + self.check_output() + + +class TestAnyOp(OpTest): + def setUp(self): + self.op_type = "reduce_any" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.outputs = {'Out': self.inputs['X'].any()} + self.attrs = {'reduce_all': True} + + def test_check_output(self): + self.check_output() + + +class TestAnyOpWithDim(OpTest): + def setUp(self): + self.op_type = "reduce_any" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.attrs = {'dim': [1]} + self.outputs = {'Out': self.inputs['X'].any(axis=1)} + + def test_check_output(self): + self.check_output() + + +class TestAnyOpWithKeepDim(OpTest): + def setUp(self): + self.op_type = "reduce_any" + self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")} + self.attrs = {'dim': [1], 'keep_dim': True} + self.outputs = { + 'Out': np.expand_dims( + self.inputs['X'].any(axis=1), axis=1) + } + + def test_check_output(self): + self.check_output() + + class Test1DReduce(OpTest): def setUp(self): self.op_type = "reduce_sum"