未验证 提交 df3ae18a 编写于 作者: A Allen Guo 提交者: GitHub

[IPU] add more ops (#40691)

* add more ops

* add authors
Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NAllen Guo <alleng@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>

* rm ipu_strategy.check()

* fix UT fail

* fix typo
Co-authored-by: NXiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: NZhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
Co-authored-by: NHan Zhao <hanzhao@graphcore.ai>
上级 b8dc673d
......@@ -34,15 +34,36 @@ Node *logical_not_handler(Graph *graph, Node *node) {
{GetOutputVarNode("Out", node)}, {});
}
Node *logical_or_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_or",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *logical_and_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_and",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *greater_than_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_greater",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *less_than_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_less",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
REGISTER_HANDLER(equal, equal_handler);
REGISTER_HANDLER(logical_not, logical_not_handler);
REGISTER_HANDLER(logical_or, logical_or_handler);
REGISTER_HANDLER(logical_and, logical_and_handler);
REGISTER_HANDLER(greater_than, greater_than_handler);
REGISTER_HANDLER(less_than, less_than_handler);
} // namespace
} // namespace ipu
......
......@@ -98,6 +98,12 @@ Node *matmul_handler(Graph *graph, Node *node) {
if (x_rank == 1) {
perm = std::vector<int64_t>{0};
} else if (x_rank == 2) {
if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) {
return CreateBaseOp(
graph, node, "popart_matmul",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs);
}
return CreateGemm(graph, node,
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs, transpose_x, transpose_y, alpha);
......
......@@ -54,10 +54,36 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) {
node->outputs);
}
Node *custom_nll_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction"));
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignoreIndex"));
auto inputIsLogProbability =
BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability"));
return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs,
node->outputs,
{{"reduction", reduction},
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", inputIsLogProbability}});
}
Node *identity_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_identity", node->inputs,
node->outputs);
}
Node *detach_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_detach_v2", node->inputs,
node->outputs);
}
REGISTER_HANDLER(custom_op, custom_op_handler);
REGISTER_HANDLER(print, print_handler);
REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);
REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler);
REGISTER_HANDLER(identity, identity_handler);
REGISTER_HANDLER(detach, detach_handler);
} // namespace
} // namespace ipu
......
......@@ -49,6 +49,9 @@ Node *fill_constant_handler(Graph *graph, Node *node) {
case framework::proto::VarType::INT64:
value = std::vector<int64_t>(size, value_);
break;
case framework::proto::VarType::BOOL:
value = std::vector<bool>(size, value_);
break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("fill_constant dtype: %d", dtype_));
......@@ -417,6 +420,45 @@ Node *assign_handler(Graph *graph, Node *node) {
{GetOutputVarNode("Out", node)}, {});
}
Node *assign_value_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = VarType2OnnxDtype(dtype_);
auto dims_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape"));
std::vector<int64_t> dims(dims_.begin(), dims_.end());
Attribute values;
std::string value_name;
switch (dtype_) {
case framework::proto::VarType::BOOL: {
value_name = "bool_values";
auto vec_int = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
std::vector<bool> vec_bool(vec_int.begin(), vec_int.end());
values = vec_bool;
} break;
case framework::proto::VarType::INT32:
value_name = "int32_values";
values = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
break;
case framework::proto::VarType::FP32:
value_name = "fp32_values";
values = BOOST_GET_CONST(std::vector<float>, op->GetAttr(value_name));
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
values = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr(value_name));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for AssignValue operator, only "
"supports bool, int32, float32 and int64.",
dtype));
}
return CreateConst(graph, node, node->inputs, node->outputs,
AttributeMap{
{"value", values}, {"dims", dims}, {"dtype", dtype},
});
}
Node *fill_any_like_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
......@@ -482,6 +524,41 @@ Node *one_hot_handler(Graph *graph, Node *node) {
}
}
Node *one_hot_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto depth = BOOST_GET_CONST(int, op->GetAttr("depth"));
auto allow_out_of_range =
BOOST_GET_CONST(bool, op->GetAttr("allow_out_of_range"));
if (allow_out_of_range) {
PADDLE_THROW(platform::errors::Unimplemented(
"Do not support allow_out_of_range=True"));
} else {
auto depth_tensor =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int>{depth}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT32}});
Node *value_tensor = nullptr;
if (GetOutputVarNode("Out", node)->Var()->GetDataType() ==
framework::proto::VarType::FP16) {
value_tensor =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
{"dims", std::vector<int64_t>{2}},
{"dtype", ONNXDataType::FLOAT16}});
} else {
value_tensor =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
{"dims", std::vector<int64_t>{2}},
{"dtype", ONNXDataType::FLOAT}});
}
return CreateBaseOp(graph, node, "popart_onehot",
{GetInputVarNode("X", node), depth_tensor->outputs[0],
value_tensor->outputs[0]},
{GetOutputVarNode("Out", node)},
{{"axis", int64_t{-1}}});
}
}
Node *split_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
......@@ -510,10 +587,12 @@ REGISTER_HANDLER(shape, shape_handler);
REGISTER_HANDLER(slice, slice_handler);
REGISTER_HANDLER(expand, expand_handler);
REGISTER_HANDLER(assign, assign_handler);
REGISTER_HANDLER(assign_value, assign_value_handler);
REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);
REGISTER_HANDLER(split, split_handler);
REGISTER_HANDLER(one_hot, one_hot_handler);
REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler);
} // namespace
} // namespace ipu
......
......@@ -542,7 +542,7 @@ class IpuStrategy(object):
def set_graph_config(self,
num_ipus=1,
is_training=True,
batch_size=1,
micro_batch_size=1,
enable_manual_shard=False):
"""
Set graph configuration to the IpuStrategy instance.
......@@ -571,7 +571,7 @@ class IpuStrategy(object):
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1,
is_training=True,
batch_size=1,
micro_batch_size=1,
enable_manual_shard=False)
"""
if num_ipus == 1 and enable_manual_shard:
......@@ -581,7 +581,7 @@ class IpuStrategy(object):
options = {
'num_ipus': num_ipus,
'is_training': is_training,
'micro_batch_size': batch_size,
'micro_batch_size': micro_batch_size,
'enable_manual_shard': enable_manual_shard,
}
self.set_options(options)
......@@ -589,6 +589,7 @@ class IpuStrategy(object):
def set_pipelining_config(self,
enable_pipelining=False,
batches_per_step=1,
enable_gradient_accumulation=False,
accumulation_factor=1):
"""
Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance.
......@@ -598,6 +599,8 @@ class IpuStrategy(object):
Default False, which means disabled.
batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1.
Default 1, which means no data pipelining.
enable_gradient_accumulation (bool, optional): Enable to accumulate gradients before updating the weights in training mode. Only if enable_pipelining=True,
enable_gradient_accumulation is able to be set True. Default False, which means no gradient accumulation.
accumulation_factor (int, optional): Specify the number of micro-batches to accumulate
before applying the varUpdate. Default 1, which means disable the accumulation.
......@@ -617,6 +620,7 @@ class IpuStrategy(object):
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_pipelining_config(enable_pipelining=False,
batches_per_step=1,
enable_gradient_accumulation=False,
accumulation_factor=1)
"""
enable_manual_shard = self.get_option('enable_manual_shard')
......@@ -627,6 +631,7 @@ class IpuStrategy(object):
options = {
'enable_pipelining': enable_pipelining,
'batches_per_step': batches_per_step,
'enable_gradient_accumulation': enable_gradient_accumulation,
'accumulation_factor': accumulation_factor,
}
self.set_options(options)
......@@ -754,6 +759,56 @@ class IpuStrategy(object):
"""
return self._ipu_strategy.get_option(option)['value']
def enable_pattern(self, pattern):
"""
Enable PopART pattern to optimize the graph.
Args:
pattern(string): the name of the pattern.
Returns:
None.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
paddle.enable_static()
ipu_strategy = static.IpuStrategy()
ipu_strategy.enable_pattern("ViewSimplifyPattern")
"""
self._ipu_strategy.enable_pattern(pattern)
def disable_pattern(self, pattern):
"""
Disable PopART pattern.
Args:
pattern(string): the name of the pattern.
Returns:
None.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
paddle.enable_static()
ipu_strategy = static.IpuStrategy()
ipu_strategy.disable_pattern("ViewSimplifyPattern")
"""
self._ipu_strategy.disable_pattern(pattern)
@property
def num_ipus(self):
"""
......@@ -817,8 +872,8 @@ class IpuCompiledProgram(object):
main_prog = static.default_main_program()
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1)
ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1)
ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
ipu_strategy.set_precision_config(enable_fp16=False)
ipu_compiled_program = static.IpuCompiledProgram(
......@@ -891,8 +946,8 @@ class IpuCompiledProgram(object):
main_prog = static.default_main_program()
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1)
ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1)
ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, enable_gradient_accumulation=False, accumulation_factor=1)
ipu_strategy.set_precision_config(enable_fp16=False)
program = static.IpuCompiledProgram(
......
......@@ -98,5 +98,117 @@ class TestBase(IPUOpTest):
self.check(output_dict)
class TestAssignFp32Value(TestBase):
def set_data_feed(self):
data = np.random.uniform(size=[2, 3, 1])
self.feed_fp32 = {'in_0': data.astype(np.float32)}
self.feed_fp16 = {'in_0': data.astype(np.float16)}
data = np.random.uniform(size=[2, 3, 1])
self.assign_fp32 = data.astype(np.float32)
def _test_base(self, exec_mode):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
assign = paddle.assign(self.assign_fp32)
out = paddle.fluid.layers.elementwise_add(x, assign)
fetch_list = [out.name]
if exec_mode == ExecutionMode.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
feed = self.feed_fp32
if exec_mode > ExecutionMode.IPU_FP32:
feed = self.feed_fp16
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
class TestAssignBoolValue(TestBase):
def set_data_feed(self):
data = np.random.uniform(size=[2, 3, 1])
self.feed_fp32 = {'in_0': data.astype(np.float32)}
self.feed_fp16 = {'in_0': data.astype(np.float16)}
data = np.random.choice([True, False], size=(2, 3, 1))
self.assign_bool = data.astype(np.bool)
def _test_base(self, exec_mode):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
x = paddle.less_than(x, x)
assign = paddle.assign(self.assign_bool)
out = paddle.logical_and(x, assign)
out = paddle.cast(out, 'float32')
fetch_list = [out.name]
if exec_mode == ExecutionMode.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
feed = self.feed_fp32
if exec_mode > ExecutionMode.IPU_FP32:
feed = self.feed_fp16
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
if __name__ == "__main__":
unittest.main()
......@@ -22,33 +22,18 @@ from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMod
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
class TestGreaterThan(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
self.set_test_op()
@property
def fp16_enabled(self):
return True
def set_data_feed(self):
x = np.random.randn(3, 4, 5)
y = np.random.randn(3, 4, 5)
self.feed_fp32 = {
"x": x.astype(np.float32),
"y": y.astype(np.float32),
}
self.feed_fp16 = {
"x": x.astype(np.float16),
"y": y.astype(np.float16),
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_test_op(self):
self.op = paddle.fluid.layers.greater_than
def set_op_attrs(self):
self.attrs = {}
......@@ -71,7 +56,7 @@ class TestBase(IPUOpTest):
shape=self.feed_shape[1],
dtype='float32')
out = paddle.fluid.layers.greater_than(x, y, **self.attrs)
out = self.op(x, y, **self.attrs)
fetch_list = [out.name]
......@@ -102,7 +87,7 @@ class TestBase(IPUOpTest):
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
def test(self):
def run_test_base(self):
output_dict = {}
for mode in ExecutionMode:
if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled:
......@@ -111,29 +96,73 @@ class TestBase(IPUOpTest):
self.check(output_dict)
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_data_feed0(self):
x = np.random.randn(3, 4, 5)
y = np.random.randn(3, 4, 5)
self.feed_fp32 = {
"x": x.astype(np.float32),
"y": y.astype(np.float32),
}
self.feed_fp16 = {
"x": x.astype(np.float16),
"y": y.astype(np.float16),
}
self.set_feed_attr()
class TestCase1(TestBase):
def set_data_feed(self):
def set_data_feed1(self):
x = np.ones([1, 10])
y = np.ones([10])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
self.set_feed_attr()
class TestCase2(TestBase):
def set_data_feed(self):
def set_data_feed2(self):
x = np.ones([1, 10])
y = np.zeros([1, 10])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
self.set_feed_attr()
class TestCase3(TestBase):
def set_data_feed(self):
def set_data_feed3(self):
x = np.zeros([1, 10])
y = np.ones([1, 10])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
self.set_feed_attr()
def test_case0(self):
self.set_data_feed0()
self.set_op_attrs()
self.run_test_base()
def test_case1(self):
self.set_data_feed1()
self.set_op_attrs()
self.run_test_base()
def test_case2(self):
self.set_data_feed2()
self.set_op_attrs()
self.run_test_base()
def test_case3(self):
self.set_data_feed3()
self.set_op_attrs()
self.run_test_base()
class TestLessThan(TestGreaterThan):
def set_test_op(self):
self.op = paddle.fluid.layers.less_than
class TestEqual(TestGreaterThan):
def set_test_op(self):
self.op = paddle.fluid.layers.equal
if __name__ == "__main__":
......
......@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import unittest
import paddle
import paddle.static
paddle.enable_static()
......@@ -26,30 +26,31 @@ paddle.enable_static()
class TestIpuShard(unittest.TestCase):
def _test(self):
# build graph
a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
b = a + 2 # scale : scale * x + bias, ipu_index : no
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
b = a + 2 # scale : scale * x + bias, ipu_index : no
with paddle.static.ipu_shard_guard(index=1):
c = b + 1 # scale, ipu_index : 1
with paddle.static.ipu_shard_guard(index=2):
d = c * 2 # scale, ipu_index : 2
with paddle.static.ipu_shard_guard(index=3):
e = d + 3 # scale, ipu_index : 3
with paddle.static.ipu_shard_guard(index=1):
e = e + 3 # scale, ipu_index : 1
with paddle.static.ipu_shard_guard(index=2):
e = e + 3 # scale, ipu_index : 2
with paddle.static.ipu_shard_guard(index=1):
f = paddle.tensor.pow(e, 2.0) # pow, ipu_index : 1
with paddle.static.ipu_shard_guard(index=1):
c = b + 1 # scale, ipu_index : 1
with paddle.static.ipu_shard_guard(index=2):
d = c * 2 # scale, ipu_index : 2
with paddle.static.ipu_shard_guard(index=3):
e = d + 3 # scale, ipu_index : 3
with paddle.static.ipu_shard_guard(index=1):
e = e + 3 # scale, ipu_index : 1
with paddle.static.ipu_shard_guard(index=2):
e = e + 3 # scale, ipu_index : 2
with paddle.static.ipu_shard_guard(index=1):
f = paddle.tensor.pow(e, 2.0) # pow, ipu_index : 1
g = f - 1 # scale, ipu_index : 2
with paddle.static.ipu_shard_guard(index=2):
g = f - 1 # scale, ipu_index : 2
h = g + 1 # scale, ipu_index : no
h = g + 1 # scale, ipu_index : no
ipu_index_list = []
main_prog = paddle.static.default_main_program()
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_index"):
ipu_index_list.append(op.desc.attr("ipu_index"))
......@@ -69,30 +70,31 @@ class TestIpuShard(unittest.TestCase):
class TestIpuPipeline(unittest.TestCase):
def _test(self):
# build graph
a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
b = a + 2 # scale : scale * x + bias, ipu_stage : no
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
b = a + 2 # scale : scale * x + bias, ipu_stage : no
with paddle.static.ipu_shard_guard(stage=1):
c = b + 1 # scale, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=2):
d = c * 2 # scale, ipu_stage : 2
with paddle.static.ipu_shard_guard(stage=3):
e = d + 3 # scale, ipu_stage : 3
with paddle.static.ipu_shard_guard(stage=1):
e = e + 3 # scale, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=2):
e = e + 3 # scale, ipu_stage : 2
with paddle.static.ipu_shard_guard(stage=1):
f = paddle.tensor.pow(e, 2.0) # pow, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=1):
c = b + 1 # scale, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=2):
d = c * 2 # scale, ipu_stage : 2
with paddle.static.ipu_shard_guard(stage=3):
e = d + 3 # scale, ipu_stage : 3
with paddle.static.ipu_shard_guard(stage=1):
e = e + 3 # scale, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=2):
e = e + 3 # scale, ipu_stage : 2
with paddle.static.ipu_shard_guard(stage=1):
f = paddle.tensor.pow(e, 2.0) # pow, ipu_stage : 1
with paddle.static.ipu_shard_guard(stage=2):
g = f - 1 # scale, ipu_stage : 2
g = f - 1 # scale, ipu_stage : 2
h = g + 1 # scale, ipu_stage : no
h = g + 1 # scale, ipu_stage : no
ipu_index_list = []
main_prog = paddle.static.default_main_program()
for op in main_prog.global_block().ops:
if op.desc.has_attr("ipu_stage"):
ipu_index_list.append(op.desc.attr("ipu_stage"))
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestLogicalAnd(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_test_op()
@property
def fp16_enabled(self):
return False
def set_test_op(self):
self.op = paddle.fluid.layers.logical_and
def set_op_attrs(self):
self.attrs = {}
def _test_base(self, exec_mode):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype=self.feed_dtype[0])
y = paddle.static.data(
name=self.feed_list[1],
shape=self.feed_shape[1],
dtype=self.feed_dtype[1])
out = self.op(x, y, **self.attrs)
fetch_list = [out.name]
if exec_mode == ExecutionMode.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
result = exe.run(program, feed=self.feed, fetch_list=fetch_list)
return result[0]
def run_test_base(self):
output_dict = {}
for mode in ExecutionMode:
if mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled:
break
output_dict[mode] = self._test_base(mode).astype(np.int32)
self.check(output_dict, check_shape=True)
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
self.feed_dtype = ['bool', 'bool']
def set_data_feed0(self):
x = np.random.choice([True, False], size=(1, 3, 5, 5))
y = np.random.choice([True, False], size=(1, 3, 5, 5))
self.feed = {
"x": x.astype('bool'),
"y": y.astype('bool'),
}
self.set_feed_attr()
def test_case0(self):
self.set_data_feed0()
self.set_op_attrs()
self.run_test_base()
class TestLogicalOr(TestLogicalAnd):
def set_test_op(self):
self.op = paddle.fluid.layers.logical_or
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
@property
def fp16_enabled(self):
return True
def set_data_feed(self):
data1 = np.array([[1], [1], [3], [0]])
self.feed = {'x': data1.astype(np.int32)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
def set_op_attrs(self):
self.attrs = {"depth": 4, "allow_out_of_range": False}
def _test_base(self, exec_mode):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='int32')
out = paddle.fluid.layers.one_hot(x, **self.attrs)
fetch_list = [out.name]
if exec_mode == ExecutionMode.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
feed = self.feed
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
def test_base(self):
output_dict = {}
for mode in ExecutionMode:
if (mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled):
break
output_dict[mode] = self._test_base(mode).flatten()
self.check(output_dict)
@unittest.skip('does not support allow_out_of_range=True')
class TestCase1(TestBase):
def set_op_attrs(self):
self.attrs = {"depth": 4, "allow_out_of_range": True}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, ExecutionMode
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
@property
def fp16_enabled(self):
return True
def set_data_feed(self):
data1 = np.array([[1], [1], [3], [0]])
self.feed = {'x': data1.astype(np.int32)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
def set_op_attrs(self):
self.attrs = {"depth": 4, "allow_out_of_range": False}
def _test_base(self, exec_mode):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = self.SEED
startup_prog.random_seed = self.SEED
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(
name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='int32')
out = paddle.fluid.input.one_hot(x, **self.attrs)
fetch_list = [out.name]
if exec_mode == ExecutionMode.CPU_FP32:
place = paddle.CPUPlace()
else:
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
else:
program = main_prog
feed = self.feed
result = exe.run(program, feed=feed, fetch_list=fetch_list)
return result[0]
def test_base(self):
output_dict = {}
for mode in ExecutionMode:
if (mode > ExecutionMode.IPU_FP32 and not self.fp16_enabled):
break
output_dict[mode] = self._test_base(mode).flatten()
self.check(output_dict)
@unittest.skip('does not support allow_out_of_range=True')
class TestCase1(TestBase):
def set_op_attrs(self):
self.attrs = {"depth": 4, "allow_out_of_range": True}
if __name__ == "__main__":
unittest.main()
......@@ -91,6 +91,15 @@ class TestBase(IPUOpTest):
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=True)
ipu_strategy.loss_scaling = self.attrs["loss_scaling"]
if "use_no_bias_optimizer" in self.attrs.keys():
ipu_strategy.set_options({
"use_no_bias_optimizer":
self.attrs["use_no_bias_optimizer"]
})
if "accl1_type" in self.attrs.keys():
ipu_strategy.set_options({
"accl1_type": self.attrs["accl1_type"]
})
program = paddle.static.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy).compile(feed_list,
fetch_list)
......@@ -141,6 +150,28 @@ class TestAdamCase2(TestBase):
}
@unittest.skip('cpu do not support AdamNoBias')
class TestAdamNoBias(TestBase):
def set_attrs(self):
self.attrs = {
"optimizer": 'adam',
"weight_decay": 0.0,
"loss_scaling": 4.0,
"use_no_bias_optimizer": True,
}
@unittest.skip('cpu do not support FLOAT16')
class TestAdamCase3(TestBase):
def set_attrs(self):
self.attrs = {
"optimizer": 'adam',
"weight_decay": 0.0,
"loss_scaling": 4.0,
"accl1_type": "FLOAT16",
}
@unittest.skip('seems cpu output wrong')
class TestLambCase1(TestBase):
def set_attrs(self):
......@@ -161,5 +192,27 @@ class TestLamb(TestBase):
}
@unittest.skip('cpu do not support LambNoBias')
class TestLambNoBias(TestBase):
def set_attrs(self):
self.attrs = {
"optimizer": 'lamb',
"weight_decay": 0.1,
"loss_scaling": 6.0,
"use_no_bias_optimizer": True
}
@unittest.skip('cpu do not support FLOAT16')
class TestLambCase2(TestBase):
def set_attrs(self):
self.attrs = {
"optimizer": 'lamb',
"weight_decay": 0.1,
"loss_scaling": 6.0,
"accl1_type": "FLOAT16"
}
if __name__ == "__main__":
unittest.main()
......@@ -88,11 +88,10 @@ class TestBase(IPUOpTest):
if exec_mode != ExecutionMode.CPU_FP32:
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
ipu_strategy.set_graph_config(
is_training=self.is_training, micro_batch_size=2)
if exec_mode == ExecutionMode.IPU_POPART_FP16:
ipu_strategy.set_precision_config(enable_fp16=True)
# set batch size
ipu_strategy.micro_batch_size = 2
program = paddle.static.IpuCompiledProgram(
main_prog,
ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册