From df3ae18a74df234434330c20eebe0368309b541d Mon Sep 17 00:00:00 2001 From: Allen Guo Date: Mon, 21 Mar 2022 10:37:46 +0800 Subject: [PATCH] [IPU] add more ops (#40691) * add more ops * add authors Co-authored-by: Xiaobing Wang Co-authored-by: Allen Guo Co-authored-by: Zhixin Yao Co-authored-by: Zhaorui Chen Co-authored-by: Han Zhao * rm ipu_strategy.check() * fix UT fail * fix typo Co-authored-by: Xiaobing Wang Co-authored-by: Zhixin Yao Co-authored-by: Zhaorui Chen Co-authored-by: Han Zhao --- .../ipu/popart_canonicalization/logic_ops.cc | 21 +++ .../ipu/popart_canonicalization/math_ops.cc | 6 + .../ipu/popart_canonicalization/other_ops.cc | 26 ++++ .../ipu/popart_canonicalization/tensor_ops.cc | 79 ++++++++++++ python/paddle/fluid/compiler.py | 69 +++++++++- .../tests/unittests/ipu/test_assign_op_ipu.py | 112 ++++++++++++++++ .../unittests/ipu/test_greater_op_ipu.py | 87 ++++++++----- .../unittests/ipu/test_ipu_shard_api_ipu.py | 82 ++++++------ .../unittests/ipu/test_logical_x_op_ipu.py | 121 ++++++++++++++++++ .../unittests/ipu/test_one_hot_op_ipu.py | 110 ++++++++++++++++ .../unittests/ipu/test_one_hot_v2_op_ipu.py | 110 ++++++++++++++++ .../tests/unittests/ipu/test_optimizer_ipu.py | 53 ++++++++ .../unittests/ipu/test_set_batch_size_ipu.py | 5 +- 13 files changed, 802 insertions(+), 79 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_logical_x_op_ipu.py create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_one_hot_op_ipu.py create mode 100644 python/paddle/fluid/tests/unittests/ipu/test_one_hot_v2_op_ipu.py diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc index c980bb780cf..7d928355345 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc @@ -34,15 +34,36 @@ Node *logical_not_handler(Graph *graph, Node *node) { {GetOutputVarNode("Out", node)}, {}); } +Node *logical_or_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_logical_or", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {GetOutputVarNode("Out", node)}, {}); +} + +Node *logical_and_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_logical_and", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {GetOutputVarNode("Out", node)}, {}); +} + Node *greater_than_handler(Graph *graph, Node *node) { return CreateBaseOp(graph, node, "popart_greater", {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, {GetOutputVarNode("Out", node)}, {}); } +Node *less_than_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_less", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {GetOutputVarNode("Out", node)}, {}); +} + REGISTER_HANDLER(equal, equal_handler); REGISTER_HANDLER(logical_not, logical_not_handler); +REGISTER_HANDLER(logical_or, logical_or_handler); +REGISTER_HANDLER(logical_and, logical_and_handler); REGISTER_HANDLER(greater_than, greater_than_handler); +REGISTER_HANDLER(less_than, less_than_handler); } // namespace } // namespace ipu diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc index d4a14a6d840..ba6675f40f4 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc +++ 
b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc @@ -98,6 +98,12 @@ Node *matmul_handler(Graph *graph, Node *node) { if (x_rank == 1) { perm = std::vector{0}; } else if (x_rank == 2) { + if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) { + return CreateBaseOp( + graph, node, "popart_matmul", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + node->outputs); + } return CreateGemm(graph, node, {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, node->outputs, transpose_x, transpose_y, alpha); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc index 0919afef4d8..8bd07943688 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc @@ -54,10 +54,36 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) { node->outputs); } +Node *custom_nll_loss_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); + auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignoreIndex")); + auto inputIsLogProbability = + BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability")); + return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, + node->outputs, + {{"reduction", reduction}, + {"ignoreIndex", ignoreIndex}, + {"inputIsLogProbability", inputIsLogProbability}}); +} + +Node *identity_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_identity", node->inputs, + node->outputs); +} + +Node *detach_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_detach_v2", node->inputs, + node->outputs); +} + REGISTER_HANDLER(custom_op, custom_op_handler); REGISTER_HANDLER(print, print_handler); REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler); REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler); +REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler); +REGISTER_HANDLER(identity, identity_handler); +REGISTER_HANDLER(detach, detach_handler); } // namespace } // namespace ipu diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc index db429d2f622..6ccb5441f83 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc @@ -49,6 +49,9 @@ Node *fill_constant_handler(Graph *graph, Node *node) { case framework::proto::VarType::INT64: value = std::vector(size, value_); break; + case framework::proto::VarType::BOOL: + value = std::vector(size, value_); + break; default: PADDLE_THROW( platform::errors::Unimplemented("fill_constant dtype: %d", dtype_)); @@ -417,6 +420,45 @@ Node *assign_handler(Graph *graph, Node *node) { {GetOutputVarNode("Out", node)}, {}); } +Node *assign_value_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype")); + auto dtype = VarType2OnnxDtype(dtype_); + auto dims_ = BOOST_GET_CONST(std::vector, op->GetAttr("shape")); + std::vector dims(dims_.begin(), dims_.end()); + Attribute values; + std::string value_name; + switch (dtype_) { + case framework::proto::VarType::BOOL: { + value_name = "bool_values"; + auto vec_int = BOOST_GET_CONST(std::vector, op->GetAttr(value_name)); + std::vector vec_bool(vec_int.begin(), vec_int.end()); 
+ values = vec_bool; + } break; + case framework::proto::VarType::INT32: + value_name = "int32_values"; + values = BOOST_GET_CONST(std::vector, op->GetAttr(value_name)); + break; + case framework::proto::VarType::FP32: + value_name = "fp32_values"; + values = BOOST_GET_CONST(std::vector, op->GetAttr(value_name)); + break; + case framework::proto::VarType::INT64: + value_name = "int64_values"; + values = BOOST_GET_CONST(std::vector, op->GetAttr(value_name)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type(code %d) for AssignValue operator, only " + "supports bool, int32, float32 and int64.", + dtype)); + } + return CreateConst(graph, node, node->inputs, node->outputs, + AttributeMap{ + {"value", values}, {"dims", dims}, {"dtype", dtype}, + }); +} + Node *fill_any_like_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto value = BOOST_GET_CONST(float, op->GetAttr("value")); @@ -482,6 +524,41 @@ Node *one_hot_handler(Graph *graph, Node *node) { } } +Node *one_hot_v2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto depth = BOOST_GET_CONST(int, op->GetAttr("depth")); + auto allow_out_of_range = + BOOST_GET_CONST(bool, op->GetAttr("allow_out_of_range")); + if (allow_out_of_range) { + PADDLE_THROW(platform::errors::Unimplemented( + "Do not support allow_out_of_range=True")); + } else { + auto depth_tensor = + CreateConst(graph, node, {}, {}, {{"value", std::vector{depth}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::INT32}}); + Node *value_tensor = nullptr; + if (GetOutputVarNode("Out", node)->Var()->GetDataType() == + framework::proto::VarType::FP16) { + value_tensor = + CreateConst(graph, node, {}, {}, {{"value", std::vector{0, 1}}, + {"dims", std::vector{2}}, + {"dtype", ONNXDataType::FLOAT16}}); + } else { + value_tensor = + CreateConst(graph, node, {}, {}, {{"value", std::vector{0, 1}}, + {"dims", std::vector{2}}, + {"dtype", ONNXDataType::FLOAT}}); + } + + return CreateBaseOp(graph, node, "popart_onehot", + {GetInputVarNode("X", node), depth_tensor->outputs[0], + value_tensor->outputs[0]}, + {GetOutputVarNode("Out", node)}, + {{"axis", int64_t{-1}}}); + } +} + Node *split_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto axis = BOOST_GET_CONST(int, op->GetAttr("axis")); @@ -510,10 +587,12 @@ REGISTER_HANDLER(shape, shape_handler); REGISTER_HANDLER(slice, slice_handler); REGISTER_HANDLER(expand, expand_handler); REGISTER_HANDLER(assign, assign_handler); +REGISTER_HANDLER(assign_value, assign_value_handler); REGISTER_HANDLER(fill_any_like, fill_any_like_handler); REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler); REGISTER_HANDLER(split, split_handler); REGISTER_HANDLER(one_hot, one_hot_handler); +REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler); } // namespace } // namespace ipu diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index b8a696057e7..d21b7e4740a 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -542,7 +542,7 @@ class IpuStrategy(object): def set_graph_config(self, num_ipus=1, is_training=True, - batch_size=1, + micro_batch_size=1, enable_manual_shard=False): """ Set graph configuration to the IpuStrategy instance. 
@@ -571,7 +571,7 @@ class IpuStrategy(object): ipu_strategy = static.IpuStrategy() ipu_strategy.set_graph_config(num_ipus=1, is_training=True, - batch_size=1, + micro_batch_size=1, enable_manual_shard=False) """ if num_ipus == 1 and enable_manual_shard: @@ -581,7 +581,7 @@ class IpuStrategy(object): options = { 'num_ipus': num_ipus, 'is_training': is_training, - 'micro_batch_size': batch_size, + 'micro_batch_size': micro_batch_size, 'enable_manual_shard': enable_manual_shard, } self.set_options(options) @@ -589,6 +589,7 @@ class IpuStrategy(object): def set_pipelining_config(self, enable_pipelining=False, batches_per_step=1, + enable_gradient_accumulation=False, accumulation_factor=1): """ Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance. @@ -598,6 +599,8 @@ class IpuStrategy(object): Default False, which means disabled. batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1. Default 1, which means no data pipelining. + enable_gradient_accumulation (bool, optional): Enable to accumulate gradients before updating the weights in training mode. Only if enable_pipelining=True, + enable_gradient_accumulation is able to be set True. Default False, which means no gradient accumulation. accumulation_factor (int, optional): Specify the number of micro-batches to accumulate before applying the varUpdate. Default 1, which means disable the accumulation. @@ -617,6 +620,7 @@ class IpuStrategy(object): ipu_strategy = static.IpuStrategy() ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, + enable_gradient_accumulation=False, accumulation_factor=1) """ enable_manual_shard = self.get_option('enable_manual_shard') @@ -627,6 +631,7 @@ class IpuStrategy(object): options = { 'enable_pipelining': enable_pipelining, 'batches_per_step': batches_per_step, + 'enable_gradient_accumulation': enable_gradient_accumulation, 'accumulation_factor': accumulation_factor, } self.set_options(options) @@ -754,6 +759,56 @@ class IpuStrategy(object): """ return self._ipu_strategy.get_option(option)['value'] + def enable_pattern(self, pattern): + """ + Enable PopART pattern to optimize the graph. + + Args: + pattern(string): the name of the pattern. + + Returns: + None. + + Examples: + .. code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + + ipu_strategy = static.IpuStrategy() + ipu_strategy.enable_pattern("ViewSimplifyPattern") + """ + self._ipu_strategy.enable_pattern(pattern) + + def disable_pattern(self, pattern): + """ + Disable PopART pattern. + + Args: + pattern(string): the name of the pattern. + + Returns: + None. + + Examples: + .. 
code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + + ipu_strategy = static.IpuStrategy() + ipu_strategy.disable_pattern("ViewSimplifyPattern") + """ + self._ipu_strategy.disable_pattern(pattern) + @property def num_ipus(self): """ @@ -817,8 +872,8 @@ class IpuCompiledProgram(object): main_prog = static.default_main_program() ipu_strategy = static.IpuStrategy() - ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1) - ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1) + ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1) + ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1) ipu_strategy.set_precision_config(enable_fp16=False) ipu_compiled_program = static.IpuCompiledProgram( @@ -891,8 +946,8 @@ class IpuCompiledProgram(object): main_prog = static.default_main_program() ipu_strategy = static.IpuStrategy() - ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1) - ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1) + ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1) + ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1) ipu_strategy.set_precision_config(enable_fp16=False) program = static.IpuCompiledProgram( diff --git a/python/paddle/fluid/tests/unittests/ipu/test_assign_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_assign_op_ipu.py index 4f17c90de72..35f4ca17d5e 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_assign_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_assign_op_ipu.py @@ -98,5 +98,117 @@ class TestBase(IPUOpTest): self.check(output_dict) +class TestAssignFp32Value(TestBase): + def set_data_feed(self): + data = np.random.uniform(size=[2, 3, 1]) + self.feed_fp32 = {'in_0': data.astype(np.float32)} + self.feed_fp16 = {'in_0': data.astype(np.float16)} + + data = np.random.uniform(size=[2, 3, 1]) + self.assign_fp32 = data.astype(np.float32) + + def _test_base(self, exec_mode): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data( + name=self.feed_list[0], + shape=self.feed_shape[0], + dtype='float32') + + assign = paddle.assign(self.assign_fp32) + out = paddle.fluid.layers.elementwise_add(x, assign) + + fetch_list = [out.name] + + if exec_mode == ExecutionMode.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + if exec_mode != ExecutionMode.CPU_FP32: + feed_list = self.feed_list + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionMode.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed_fp32 + if exec_mode > ExecutionMode.IPU_FP32: + feed = self.feed_fp16 + + result = exe.run(program, feed=feed, 
fetch_list=fetch_list) + return result[0] + + +class TestAssignBoolValue(TestBase): + def set_data_feed(self): + data = np.random.uniform(size=[2, 3, 1]) + self.feed_fp32 = {'in_0': data.astype(np.float32)} + self.feed_fp16 = {'in_0': data.astype(np.float16)} + data = np.random.choice([True, False], size=(2, 3, 1)) + self.assign_bool = data.astype(np.bool) + + def _test_base(self, exec_mode): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data( + name=self.feed_list[0], + shape=self.feed_shape[0], + dtype='float32') + x = paddle.less_than(x, x) + assign = paddle.assign(self.assign_bool) + out = paddle.logical_and(x, assign) + out = paddle.cast(out, 'float32') + + fetch_list = [out.name] + + if exec_mode == ExecutionMode.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + if exec_mode != ExecutionMode.CPU_FP32: + feed_list = self.feed_list + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionMode.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed_fp32 + if exec_mode > ExecutionMode.IPU_FP32: + feed = self.feed_fp16 + + result = exe.run(program, feed=feed, fetch_list=fetch_list) + return result[0] + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_greater_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_greater_op_ipu.py index 05a37dcb3d5..934ad101428 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_greater_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_greater_op_ipu.py @@ -22,33 +22,18 @@ from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMod @unittest.skipIf(not paddle.is_compiled_with_ipu(), "core is not compiled with IPU") -class TestBase(IPUOpTest): +class TestGreaterThan(IPUOpTest): def setUp(self): self.set_atol() self.set_training() - self.set_data_feed() - self.set_feed_attr() - self.set_op_attrs() + self.set_test_op() @property def fp16_enabled(self): return True - def set_data_feed(self): - x = np.random.randn(3, 4, 5) - y = np.random.randn(3, 4, 5) - self.feed_fp32 = { - "x": x.astype(np.float32), - "y": y.astype(np.float32), - } - self.feed_fp16 = { - "x": x.astype(np.float16), - "y": y.astype(np.float16), - } - - def set_feed_attr(self): - self.feed_shape = [x.shape for x in self.feed_fp32.values()] - self.feed_list = list(self.feed_fp32.keys()) + def set_test_op(self): + self.op = paddle.fluid.layers.greater_than def set_op_attrs(self): self.attrs = {} @@ -71,7 +56,7 @@ class TestBase(IPUOpTest): shape=self.feed_shape[1], dtype='float32') - out = paddle.fluid.layers.greater_than(x, y, **self.attrs) + out = self.op(x, y, **self.attrs) fetch_list = [out.name] @@ -102,7 +87,7 @@ class TestBase(IPUOpTest): result = exe.run(program, feed=feed, fetch_list=fetch_list) return result[0] - def test(self): + def run_test_base(self): output_dict = {} for mode in ExecutionMode: if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled: @@ -111,29 +96,73 @@ class 
TestBase(IPUOpTest): self.check(output_dict) + def set_feed_attr(self): + self.feed_shape = [x.shape for x in self.feed_fp32.values()] + self.feed_list = list(self.feed_fp32.keys()) + + def set_data_feed0(self): + x = np.random.randn(3, 4, 5) + y = np.random.randn(3, 4, 5) + self.feed_fp32 = { + "x": x.astype(np.float32), + "y": y.astype(np.float32), + } + self.feed_fp16 = { + "x": x.astype(np.float16), + "y": y.astype(np.float16), + } + self.set_feed_attr() -class TestCase1(TestBase): - def set_data_feed(self): + def set_data_feed1(self): x = np.ones([1, 10]) y = np.ones([10]) self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)} self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)} + self.set_feed_attr() - -class TestCase2(TestBase): - def set_data_feed(self): + def set_data_feed2(self): x = np.ones([1, 10]) y = np.zeros([1, 10]) self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)} self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)} + self.set_feed_attr() - -class TestCase3(TestBase): - def set_data_feed(self): + def set_data_feed3(self): x = np.zeros([1, 10]) y = np.ones([1, 10]) self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)} self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)} + self.set_feed_attr() + + def test_case0(self): + self.set_data_feed0() + self.set_op_attrs() + self.run_test_base() + + def test_case1(self): + self.set_data_feed1() + self.set_op_attrs() + self.run_test_base() + + def test_case2(self): + self.set_data_feed2() + self.set_op_attrs() + self.run_test_base() + + def test_case3(self): + self.set_data_feed3() + self.set_op_attrs() + self.run_test_base() + + +class TestLessThan(TestGreaterThan): + def set_test_op(self): + self.op = paddle.fluid.layers.less_than + + +class TestEqual(TestGreaterThan): + def set_test_op(self): + self.op = paddle.fluid.layers.equal if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ipu/test_ipu_shard_api_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_ipu_shard_api_ipu.py index 026b19eccf1..76ab1a2c3f3 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_ipu_shard_api_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_ipu_shard_api_ipu.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import print_function +import unittest import numpy as np -import unittest import paddle +import paddle.static paddle.enable_static() @@ -26,30 +26,31 @@ paddle.enable_static() class TestIpuShard(unittest.TestCase): def _test(self): # build graph - a = paddle.static.data(name='data', shape=[None, 1], dtype='int32') - b = a + 2 # scale : scale * x + bias, ipu_index : no + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): + a = paddle.static.data(name='data', shape=[None, 1], dtype='int32') + b = a + 2 # scale : scale * x + bias, ipu_index : no + + with paddle.static.ipu_shard_guard(index=1): + c = b + 1 # scale, ipu_index : 1 + with paddle.static.ipu_shard_guard(index=2): + d = c * 2 # scale, ipu_index : 2 + with paddle.static.ipu_shard_guard(index=3): + e = d + 3 # scale, ipu_index : 3 + with paddle.static.ipu_shard_guard(index=1): + e = e + 3 # scale, ipu_index : 1 + with paddle.static.ipu_shard_guard(index=2): + e = e + 3 # scale, ipu_index : 2 + + with paddle.static.ipu_shard_guard(index=1): + f = paddle.tensor.pow(e, 2.0) # pow, ipu_index : 1 - with paddle.static.ipu_shard_guard(index=1): - c = b + 1 # scale, ipu_index : 1 with paddle.static.ipu_shard_guard(index=2): - d = c * 2 # scale, ipu_index : 2 - with paddle.static.ipu_shard_guard(index=3): - e = d + 3 # scale, ipu_index : 3 - with paddle.static.ipu_shard_guard(index=1): - e = e + 3 # scale, ipu_index : 1 - with paddle.static.ipu_shard_guard(index=2): - e = e + 3 # scale, ipu_index : 2 - - with paddle.static.ipu_shard_guard(index=1): - f = paddle.tensor.pow(e, 2.0) # pow, ipu_index : 1 + g = f - 1 # scale, ipu_index : 2 - with paddle.static.ipu_shard_guard(index=2): - g = f - 1 # scale, ipu_index : 2 - - h = g + 1 # scale, ipu_index : no + h = g + 1 # scale, ipu_index : no ipu_index_list = [] - main_prog = paddle.static.default_main_program() for op in main_prog.global_block().ops: if op.desc.has_attr("ipu_index"): ipu_index_list.append(op.desc.attr("ipu_index")) @@ -69,30 +70,31 @@ class TestIpuShard(unittest.TestCase): class TestIpuPipeline(unittest.TestCase): def _test(self): # build graph - a = paddle.static.data(name='data', shape=[None, 1], dtype='int32') - b = a + 2 # scale : scale * x + bias, ipu_stage : no + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): + a = paddle.static.data(name='data', shape=[None, 1], dtype='int32') + b = a + 2 # scale : scale * x + bias, ipu_stage : no + + with paddle.static.ipu_shard_guard(stage=1): + c = b + 1 # scale, ipu_stage : 1 + with paddle.static.ipu_shard_guard(stage=2): + d = c * 2 # scale, ipu_stage : 2 + with paddle.static.ipu_shard_guard(stage=3): + e = d + 3 # scale, ipu_stage : 3 + with paddle.static.ipu_shard_guard(stage=1): + e = e + 3 # scale, ipu_stage : 1 + with paddle.static.ipu_shard_guard(stage=2): + e = e + 3 # scale, ipu_stage : 2 + + with paddle.static.ipu_shard_guard(stage=1): + f = paddle.tensor.pow(e, 2.0) # pow, ipu_stage : 1 - with paddle.static.ipu_shard_guard(stage=1): - c = b + 1 # scale, ipu_stage : 1 with paddle.static.ipu_shard_guard(stage=2): - d = c * 2 # scale, ipu_stage : 2 - with paddle.static.ipu_shard_guard(stage=3): - e = d + 3 # scale, ipu_stage : 3 - with paddle.static.ipu_shard_guard(stage=1): - e = e + 3 # scale, ipu_stage : 1 - with paddle.static.ipu_shard_guard(stage=2): - e = e + 3 # scale, ipu_stage : 2 - - with paddle.static.ipu_shard_guard(stage=1): - f = paddle.tensor.pow(e, 2.0) # pow, ipu_stage : 1 - - with paddle.static.ipu_shard_guard(stage=2): - g = 
f - 1 # scale, ipu_stage : 2 + g = f - 1 # scale, ipu_stage : 2 - h = g + 1 # scale, ipu_stage : no + h = g + 1 # scale, ipu_stage : no ipu_index_list = [] - main_prog = paddle.static.default_main_program() for op in main_prog.global_block().ops: if op.desc.has_attr("ipu_stage"): ipu_index_list.append(op.desc.attr("ipu_stage")) diff --git a/python/paddle/fluid/tests/unittests/ipu/test_logical_x_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_logical_x_op_ipu.py new file mode 100644 index 00000000000..05572a72ea8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/test_logical_x_op_ipu.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle +import paddle.static +from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode + + +@unittest.skipIf(not paddle.is_compiled_with_ipu(), + "core is not compiled with IPU") +class TestLogicalAnd(IPUOpTest): + def setUp(self): + self.set_atol() + self.set_training() + self.set_test_op() + + @property + def fp16_enabled(self): + return False + + def set_test_op(self): + self.op = paddle.fluid.layers.logical_and + + def set_op_attrs(self): + self.attrs = {} + + def _test_base(self, exec_mode): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data( + name=self.feed_list[0], + shape=self.feed_shape[0], + dtype=self.feed_dtype[0]) + y = paddle.static.data( + name=self.feed_list[1], + shape=self.feed_shape[1], + dtype=self.feed_dtype[1]) + + out = self.op(x, y, **self.attrs) + + fetch_list = [out.name] + + if exec_mode == ExecutionMode.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + if exec_mode != ExecutionMode.CPU_FP32: + feed_list = self.feed_list + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionMode.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + result = exe.run(program, feed=self.feed, fetch_list=fetch_list) + return result[0] + + def run_test_base(self): + output_dict = {} + for mode in ExecutionMode: + if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled: + break + output_dict[mode] = self._test_base(mode).astype(np.int32) + + self.check(output_dict, check_shape=True) + + def set_feed_attr(self): + self.feed_shape = [x.shape for x in self.feed.values()] + self.feed_list = list(self.feed.keys()) + self.feed_dtype = ['bool', 'bool'] + + def set_data_feed0(self): + x = 
np.random.choice([True, False], size=(1, 3, 5, 5)) + y = np.random.choice([True, False], size=(1, 3, 5, 5)) + self.feed = { + "x": x.astype('bool'), + "y": y.astype('bool'), + } + self.set_feed_attr() + + def test_case0(self): + self.set_data_feed0() + self.set_op_attrs() + self.run_test_base() + + +class TestLogicalOr(TestLogicalAnd): + def set_test_op(self): + self.op = paddle.fluid.layers.logical_or + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_one_hot_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_one_hot_op_ipu.py new file mode 100644 index 00000000000..33a5dc888c2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/test_one_hot_op_ipu.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle +import paddle.static +from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode + + +@unittest.skipIf(not paddle.is_compiled_with_ipu(), + "core is not compiled with IPU") +class TestBase(IPUOpTest): + def setUp(self): + self.set_atol() + self.set_training() + self.set_data_feed() + self.set_feed_attr() + self.set_op_attrs() + + @property + def fp16_enabled(self): + return True + + def set_data_feed(self): + data1 = np.array([[1], [1], [3], [0]]) + + self.feed = {'x': data1.astype(np.int32)} + + def set_feed_attr(self): + self.feed_shape = [x.shape for x in self.feed.values()] + self.feed_list = list(self.feed.keys()) + + def set_op_attrs(self): + self.attrs = {"depth": 4, "allow_out_of_range": False} + + def _test_base(self, exec_mode): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data( + name=self.feed_list[0], + shape=self.feed_shape[0], + dtype='int32') + + out = paddle.fluid.layers.one_hot(x, **self.attrs) + + fetch_list = [out.name] + + if exec_mode == ExecutionMode.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + if exec_mode != ExecutionMode.CPU_FP32: + feed_list = self.feed_list + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionMode.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed + + result = exe.run(program, feed=feed, fetch_list=fetch_list) + + return result[0] + + def test_base(self): + output_dict = {} + for mode in ExecutionMode: + if (mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled): + break + output_dict[mode] = 
self._test_base(mode).flatten() + + self.check(output_dict) + + +@unittest.skip('does not support allow_out_of_range=True') +class TestCase1(TestBase): + def set_op_attrs(self): + self.attrs = {"depth": 4, "allow_out_of_range": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_one_hot_v2_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_one_hot_v2_op_ipu.py new file mode 100644 index 00000000000..79fc9b04e16 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/test_one_hot_v2_op_ipu.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle +import paddle.static +from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode + + +@unittest.skipIf(not paddle.is_compiled_with_ipu(), + "core is not compiled with IPU") +class TestBase(IPUOpTest): + def setUp(self): + self.set_atol() + self.set_training() + self.set_data_feed() + self.set_feed_attr() + self.set_op_attrs() + + @property + def fp16_enabled(self): + return True + + def set_data_feed(self): + data1 = np.array([[1], [1], [3], [0]]) + + self.feed = {'x': data1.astype(np.int32)} + + def set_feed_attr(self): + self.feed_shape = [x.shape for x in self.feed.values()] + self.feed_list = list(self.feed.keys()) + + def set_op_attrs(self): + self.attrs = {"depth": 4, "allow_out_of_range": False} + + def _test_base(self, exec_mode): + scope = paddle.static.Scope() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = self.SEED + startup_prog.random_seed = self.SEED + + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data( + name=self.feed_list[0], + shape=self.feed_shape[0], + dtype='int32') + + out = paddle.fluid.input.one_hot(x, **self.attrs) + + fetch_list = [out.name] + + if exec_mode == ExecutionMode.CPU_FP32: + place = paddle.CPUPlace() + else: + place = paddle.IPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + if exec_mode != ExecutionMode.CPU_FP32: + feed_list = self.feed_list + ipu_strategy = paddle.static.IpuStrategy() + ipu_strategy.set_graph_config(is_training=self.is_training) + if exec_mode == ExecutionMode.IPU_POPART_FP16: + ipu_strategy.set_precision_config(enable_fp16=True) + program = paddle.static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) + else: + program = main_prog + + feed = self.feed + + result = exe.run(program, feed=feed, fetch_list=fetch_list) + + return result[0] + + def test_base(self): + output_dict = {} + for mode in ExecutionMode: + if (mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled): + break + output_dict[mode] = self._test_base(mode).flatten() + + self.check(output_dict) + + +@unittest.skip('does not support allow_out_of_range=True') +class TestCase1(TestBase): + def 
set_op_attrs(self): + self.attrs = {"depth": 4, "allow_out_of_range": True} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_optimizer_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_optimizer_ipu.py index 1cc10da3d73..bc9d05c4a87 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_optimizer_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_optimizer_ipu.py @@ -91,6 +91,15 @@ class TestBase(IPUOpTest): ipu_strategy = paddle.static.IpuStrategy() ipu_strategy.set_graph_config(is_training=True) ipu_strategy.loss_scaling = self.attrs["loss_scaling"] + if "use_no_bias_optimizer" in self.attrs.keys(): + ipu_strategy.set_options({ + "use_no_bias_optimizer": + self.attrs["use_no_bias_optimizer"] + }) + if "accl1_type" in self.attrs.keys(): + ipu_strategy.set_options({ + "accl1_type": self.attrs["accl1_type"] + }) program = paddle.static.IpuCompiledProgram( main_prog, ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) @@ -141,6 +150,28 @@ class TestAdamCase2(TestBase): } +@unittest.skip('cpu do not support AdamNoBias') +class TestAdamNoBias(TestBase): + def set_attrs(self): + self.attrs = { + "optimizer": 'adam', + "weight_decay": 0.0, + "loss_scaling": 4.0, + "use_no_bias_optimizer": True, + } + + +@unittest.skip('cpu do not support FLOAT16') +class TestAdamCase3(TestBase): + def set_attrs(self): + self.attrs = { + "optimizer": 'adam', + "weight_decay": 0.0, + "loss_scaling": 4.0, + "accl1_type": "FLOAT16", + } + + @unittest.skip('seems cpu output wrong') class TestLambCase1(TestBase): def set_attrs(self): @@ -161,5 +192,27 @@ class TestLamb(TestBase): } +@unittest.skip('cpu do not support LambNoBias') +class TestLambNoBias(TestBase): + def set_attrs(self): + self.attrs = { + "optimizer": 'lamb', + "weight_decay": 0.1, + "loss_scaling": 6.0, + "use_no_bias_optimizer": True + } + + +@unittest.skip('cpu do not support FLOAT16') +class TestLambCase2(TestBase): + def set_attrs(self): + self.attrs = { + "optimizer": 'lamb', + "weight_decay": 0.1, + "loss_scaling": 6.0, + "accl1_type": "FLOAT16" + } + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_set_batch_size_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_set_batch_size_ipu.py index 9a18922f353..6702ae4344e 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_set_batch_size_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_set_batch_size_ipu.py @@ -88,11 +88,10 @@ class TestBase(IPUOpTest): if exec_mode != ExecutionMode.CPU_FP32: feed_list = self.feed_list ipu_strategy = paddle.static.IpuStrategy() - ipu_strategy.set_graph_config(is_training=self.is_training) + ipu_strategy.set_graph_config( + is_training=self.is_training, micro_batch_size=2) if exec_mode == ExecutionMode.IPU_POPART_FP16: ipu_strategy.set_precision_config(enable_fp16=True) - # set batch size - ipu_strategy.micro_batch_size = 2 program = paddle.static.IpuCompiledProgram( main_prog, ipu_strategy=ipu_strategy).compile(feed_list, fetch_list) -- GitLab
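
Usage notes for the API changes above. The IpuStrategy calls touched by this patch can be combined as below; this is a condensed sketch assembled from the docstring examples added to python/paddle/fluid/compiler.py, assumes an IPU build of Paddle (the docstrings mark it `# required: ipu`), and is illustrative rather than part of the change.

    import paddle
    import paddle.static as static

    paddle.enable_static()

    ipu_strategy = static.IpuStrategy()
    # set_graph_config: the old batch_size argument is renamed micro_batch_size
    ipu_strategy.set_graph_config(num_ipus=1,
                                  is_training=True,
                                  micro_batch_size=1,
                                  enable_manual_shard=False)
    # set_pipelining_config gains enable_gradient_accumulation; it only takes
    # effect when enable_pipelining=True
    ipu_strategy.set_pipelining_config(enable_pipelining=False,
                                       batches_per_step=1,
                                       enable_gradient_accumulation=False,
                                       accumulation_factor=1)
    # PopART graph patterns can now be toggled by name
    ipu_strategy.enable_pattern("ViewSimplifyPattern")
    ipu_strategy.disable_pattern("ViewSimplifyPattern")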
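
The newly registered logic handlers (logical_and, logical_or, less_than, greater_than) can be exercised from the static-graph API in the same way as the added unit test test_logical_x_op_ipu.py does. The sketch below is a trimmed version of that test; it assumes an IPU device is available, and the tensor names and shapes are arbitrary.

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()

    main_prog = static.Program()
    startup_prog = static.Program()
    with static.program_guard(main_prog, startup_prog):
        x = static.data(name='x', shape=[1, 3, 5, 5], dtype='bool')
        y = static.data(name='y', shape=[1, 3, 5, 5], dtype='bool')
        # lowered to popart_logical_and; logical_or, less_than and greater_than
        # follow the same pattern
        out = paddle.fluid.layers.logical_and(x, y)

    exe = static.Executor(paddle.IPUPlace())
    exe.run(startup_prog)

    ipu_strategy = static.IpuStrategy()
    ipu_strategy.set_graph_config(is_training=False)
    program = static.IpuCompiledProgram(
        main_prog, ipu_strategy=ipu_strategy).compile(['x', 'y'], [out.name])

    feed = {
        'x': np.random.choice([True, False], size=(1, 3, 5, 5)),
        'y': np.random.choice([True, False], size=(1, 3, 5, 5)),
    }
    result = exe.run(program, feed=feed, fetch_list=[out.name])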
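
assign_value and one_hot_v2 are now handled as well; the graph below mirrors the new TestAssignFp32Value and one_hot_v2 tests. It only builds the program (compilation and execution follow the same IpuCompiledProgram flow as above), and the variable names and shapes are placeholders.

    import numpy as np
    import paddle
    import paddle.static

    paddle.enable_static()

    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        x = paddle.static.data(name='in_0', shape=[2, 3, 1], dtype='float32')
        # paddle.assign on a numpy array emits assign_value, which is lowered to
        # a popart constant (bool/int32/fp32/int64 values are supported)
        const = paddle.assign(np.ones([2, 3, 1], dtype=np.float32))
        out = paddle.fluid.layers.elementwise_add(x, const)

        labels = paddle.static.data(name='label', shape=[4, 1], dtype='int32')
        # one_hot_v2 maps to popart_onehot; allow_out_of_range=True is rejected
        onehot = paddle.fluid.input.one_hot(labels, depth=4,
                                            allow_out_of_range=False)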
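
Finally, the optimizer tests introduce two IpuStrategy options that are set through set_options rather than dedicated setters. The snippet assumes a training graph has already been built; the option values are the ones used in the (currently skipped) TestAdamNoBias and TestAdamCase3 cases.

    ipu_strategy = paddle.static.IpuStrategy()
    ipu_strategy.set_graph_config(is_training=True)
    ipu_strategy.loss_scaling = 4.0
    # switch Adam/Lamb to their *NoBias variants on the IPU
    ipu_strategy.set_options({"use_no_bias_optimizer": True})
    # store the optimizer's first accumulator (accl1) in FLOAT16
    ipu_strategy.set_options({"accl1_type": "FLOAT16"})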