Unverified commit 42d58ddd, authored by Allen Guo, committed by GitHub

[IPU] small bug fix (#44473)

* sync misc changes

* add authors
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>

* up x

* Revert "up x"

This reverts commit f3fde458c6cc48613269a643cfe2acf689caccd3.

* add guard for ipu
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>
Parent 35ca1ce4
@@ -300,7 +300,7 @@ void Compiler::RegisterOpFunc() {
#define INT32 std::int32_t
#define BOOL bool
#define STRING std::string
-#define STRING_VEC std::vector<std::string*>
+#define STRING_VEC std::vector<std::string>
#define NONE
#define ARG(Type, Name) , GetAttrAllowNull<Type>(#Name, op_desc)
......
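Why the STRING_VEC change matters: the macro expands into the template argument of GetAttrAllowNull, so a vector of raw string pointers could never match the std::vector<std::string> that the op attribute actually stores. A minimal sketch of the idea, using a simplified stand-in for GetAttrAllowNull and a flattened ARG macro (the real ones take an op_desc and splice into an argument list):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for attribute lookup: the real GetAttrAllowNull
// reads from an OpDesc, but a plain map is enough to show why the
// macro's expansion must match the stored attribute type.
std::map<std::string, std::vector<std::string>> g_attrs = {
    {"stopwords", {"a", "the"}}};

template <typename Type>
Type GetAttrAllowNull(const std::string &name) {
  auto it = g_attrs.find(name);
  return it == g_attrs.end() ? Type{} : it->second;
}

#define STRING_VEC std::vector<std::string>  // the fixed expansion
#define ARG(Type, Name) GetAttrAllowNull<Type>(#Name)

int main() {
  // Expands to GetAttrAllowNull<std::vector<std::string>>("stopwords").
  // The old std::vector<std::string*> expansion could never convert
  // from the stored std::vector<std::string> value.
  auto stopwords = ARG(STRING_VEC, stopwords);
  std::cout << stopwords.size() << " stopwords\n";
}
```

Passing the type through a single-token alias like STRING_VEC also keeps the comma inside std::vector<std::string> from splitting the macro argument list.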
@@ -159,8 +159,8 @@ class IpuStrategy {
const std::string &type_str) {
auto it = options.find(key);
PADDLE_ENFORCE_NE(
-        it,
-        options.end(),
+        it == options.end(),
+        true,
platform::errors::InvalidArgument("Cannot find option: %s, type: %s "
"when setting IpuStrategy options",
key,
@@ -174,8 +174,8 @@ class IpuStrategy {
std::map<std::string, std::function<ValueType()>> &options) { // NOLINT
auto it = options.find(key);
PADDLE_ENFORCE_NE(
-        it,
-        options.end(),
+        it == options.end(),
+        true,
platform::errors::InvalidArgument(
"Cannot find option name: %s when trying to get IpuStrategy option",
key));
......
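Both hunks replace an iterator/sentinel operand pair with a boolean compared against true. A plausible reason: PADDLE_ENFORCE_NE stringifies both operands into the failure message, and a std::map iterator has no useful printable representation. A sketch of the pattern with a hypothetical stand-in macro:

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical stand-in: the real PADDLE_ENFORCE_NE also formats both
// operands into the error message, which is why a raw map iterator is
// a poor operand -- it has nothing sensible to print.
#define ENFORCE_NE(a, b, msg) assert(((a) != (b)) && msg)

void SetOption(const std::string &key,
               std::map<std::string, std::function<void()>> &options) {
  auto it = options.find(key);
  // Old: ENFORCE_NE(it, options.end(), ...) -- iterator operands.
  // New: collapse the comparison to a bool and check it against true.
  ENFORCE_NE(it == options.end(), true, "Cannot find option");
  it->second();
}

int main() {
  std::map<std::string, std::function<void()>> options = {
      {"enable_fp16", [] { /* toggle fp16 here */ }}};
  SetOption("enable_fp16", options);
}
```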
@@ -285,7 +285,7 @@ Node *binary_cross_entropy_handler(Graph *graph, Node *node) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss =
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Loss", node));
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Out", node));
auto x = GetInputVarNode("X", node);
auto label = GetInputVarNode("Label", node);
@@ -478,12 +478,12 @@ Node *warpctc_handler(Graph *graph, Node *node) {
auto loss = CreateBaseOp(
graph,
node,
"popart_ctcloss",
"popart_ctcloss_v2",
{log_softmax_logits, cast_label, cast_logits_length, cast_label_length},
append_identity_loss
? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Loss", node)},
{{"blank", blank},
{{"blank", int64_t{blank}},
{"reduction", reduction},
{"outDataType", std::string("UNDEFINED")}});
if (append_identity_loss) {
......
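The int64_t{blank} change pins the attribute to a 64-bit integer. Paddle attribute values are variant-like, so passing a plain int would presumably construct the wrong alternative, and brace initialization additionally rejects narrowing. An illustrative sketch with a hypothetical std::variant-backed attribute map:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>

// Hypothetical attribute value type: the assumed reason int64_t{blank}
// matters is that a plain `int` would select the `int` alternative,
// and a later std::get<int64_t> on that entry would throw.
using AttrValue = std::variant<int, int64_t, float, std::string>;

int main() {
  int blank = 0;
  std::map<std::string, AttrValue> attrs = {
      {"blank", int64_t{blank}},  // brace-init pins the 64-bit alternative
      {"outDataType", std::string("UNDEFINED")}};
  std::cout << std::get<int64_t>(attrs.at("blank")) << "\n";
}
```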
@@ -32,6 +32,39 @@ Node *custom_op_handler(Graph *graph, Node *node) {
Node *print_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto print_output = node->outputs.front();
+  auto print_input = node->inputs.front();
+  if (print_output->outputs.size() == 0) {
+    LOG(WARNING) << "The output of Print OP is not used on IPU graph. Setting "
+                    "the input of Print as Output.";
+    for (auto &subnode : print_input->outputs) {
+      if (subnode == node) continue;
+      ConnectNodes(print_output, subnode);
+      DisConnectNodes(print_input, subnode);
+      // replace node_name in op_desc
+      std::vector<std::string> new_inputs;
+      auto subnode_inmap = subnode->Op()->Inputs();
+      for (auto &in_map : subnode_inmap) {
+        if (std::find(in_map.second.begin(),
+                      in_map.second.end(),
+                      print_input->Name()) != in_map.second.end()) {
+          std::transform(in_map.second.cbegin(),
+                         in_map.second.cend(),
+                         std::back_inserter(new_inputs),
+                         [&](const std::string &node_name) {
+                           if (node_name == print_input->Name()) {
+                             return print_output->Name();
+                           } else {
+                             return node_name;
+                           }
+                         });
+          subnode->Op()->SetInput(in_map.first, new_inputs);
+          subnode->Op()->Flush();
+        }
+      }
+    }
+  }
auto print_phase = PADDLE_GET_CONST(std::string, op->GetAttr("print_phase"));
int64_t print_gradient = 0;
if (print_phase != "forward") {
......
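The heart of the new Print branch is rewiring each consumer from the Print input variable to the Print output, then rewriting the matching names in the consumer's OpDesc input map with std::transform. That substitution in isolation (names are illustrative):

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main() {
  // Illustrative stand-ins for print_input->Name() / print_output->Name().
  const std::string print_input = "x";
  const std::string print_output = "print_x";

  // A consumer op's input slot that currently reads the Print input.
  std::vector<std::string> in_map = {"x", "label"};

  std::vector<std::string> new_inputs;
  std::transform(in_map.cbegin(), in_map.cend(),
                 std::back_inserter(new_inputs),
                 [&](const std::string &node_name) {
                   return node_name == print_input ? print_output : node_name;
                 });

  for (const auto &name : new_inputs) std::cout << name << "\n";
  // Prints: print_x, label -- the consumer now reads the Print output.
}
```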
@@ -17,11 +17,13 @@
#pragma once
// Ops from AiGraphcoreOpset1
OP_DECL(popart_copyvarupdate_v2, aiGraphcoreOpset.copyvarupdate, NONE) // NOLINT
OP_DECL(popart_groupnormalization_v2, aiGraphcoreOpset.groupnormalization, ARG(INT,num_groups) ARG(FLOAT,epsilon) ) // NOLINT
OP_DECL(popart_subsample_v2, aiGraphcoreOpset.subsample, ARG(INT_VEC,strides) ) // NOLINT
OP_DECL(popart_nop_v2, aiGraphcoreOpset.nop, NONE) // NOLINT
OP_DECL(popart_scale_v2, aiGraphcoreOpset.scale, ARG(FLOAT,scale) ) // NOLINT
OP_DECL(popart_scaledadd_v2, aiGraphcoreOpset.scaledadd, ARG(FLOAT,scale0) ARG(FLOAT,scale1) ) // NOLINT
OP_DECL(popart_lstm_v2, aiGraphcoreOpset.lstm, ARG(INT,outputFullSequence) ) // NOLINT
OP_DECL(popart_gelu_v2, aiGraphcoreOpset.gelu, NONE) // NOLINT
OP_DECL(popart_detach_v2, aiGraphcoreOpset.detach, NONE) // NOLINT
OP_DECL(popart_depthtospace_v2, aiGraphcoreOpset.depthtospace, ARG(INT,blocksize) ARG(STRING,mode) ) // NOLINT
@@ -32,8 +34,13 @@ OP_DECL(popart_dynamiczero_v2, aiGraphcoreOpset.dynamiczero, ARG(INT_VEC,axes) A
OP_DECL(popart_dynamicadd_v2, aiGraphcoreOpset.dynamicadd, ARG(INT_VEC,axes) ARG(INT_VEC,sizes) ) // NOLINT
OP_DECL(popart_sequenceslice_v2, aiGraphcoreOpset.sequenceslice, ARG(INT,zeroUnused) ) // NOLINT
OP_DECL(popart_replicatedallreduce_v2, aiGraphcoreOpset.replicatedallreduce, OPT_ARG(INT_VEC,commGroup) ) // NOLINT
OP_DECL(popart_l1loss_v2, aiGraphcoreOpset.l1loss, ARG(FLOAT,lambda) SIG_ARG(INT32,popart::ReductionType,reduction) ) // NOLINT
OP_DECL(popart_nllloss_v2, aiGraphcoreOpset.nllloss, SIG_ARG(INT32,popart::ReductionType,reduction) OPT_ARG(INT32,ignoreIndex) ARG(BOOL,inputIsLogProbability) ) // NOLINT
OP_DECL(popart_identityloss_v2, aiGraphcoreOpset.identityloss, SIG_ARG(INT32,popart::ReductionType,reduction) ) // NOLINT
OP_DECL(popart_tensorremap_v2, aiGraphcoreOpset.tensorremap, ARG(INT,remap_type) ) // NOLINT
OP_DECL(popart_ctcloss_v2, aiGraphcoreOpset.ctcloss, SIG_ARG(INT32,popart::ReductionType,reduction) ARG(INT,blank) ARG(STRING,outDataType) ) // NOLINT
OP_DECL(popart__ctcloss_v2, aiGraphcoreOpset._ctcloss, SIG_ARG(INT32,popart::ReductionType,reduction) ARG(INT,blank) ARG(STRING,outDataType) ) // NOLINT
OP_DECL(popart_ctcbeamsearchdecoder_v2, aiGraphcoreOpset.ctcbeamsearchdecoder, ARG(INT,blank) ARG(INT,beamWidth) ARG(INT,topPaths) ) // NOLINT
OP_DECL(popart_ctcloss, aiGraphcoreOpset.ctcloss, SIG_ARG(INT32,popart::ReductionType,reduction) ARG(INT32,blank) ARG(STRING,outDataType) ) // NOLINT
OP_DECL(popart_shapeddropout_v2, aiGraphcoreOpset.shapeddropout, ARG(INT_VEC,shape) ARG(FLOAT,ratio) ) // NOLINT
OP_DECL(popart_atan2_v2, aiGraphcoreOpset.atan2, NONE) // NOLINT
OP_DECL(popart_expm1_v2, aiGraphcoreOpset.expm1, NONE) // NOLINT
@@ -47,6 +54,9 @@ OP_DECL(popart_bitwiseor_v2, aiGraphcoreOpset.bitwiseor, NONE) // NOLINT
OP_DECL(popart_bitwisexor_v2, aiGraphcoreOpset.bitwisexor, NONE) // NOLINT
OP_DECL(popart_bitwisexnor_v2, aiGraphcoreOpset.bitwisexnor, NONE) // NOLINT
OP_DECL(popart_reducemedian_v2, aiGraphcoreOpset.reducemedian, OPT_ARG(INT_VEC,axes) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_scatterreduce_v2, aiGraphcoreOpset.scatterreduce, ARG(INT,axis_size) ARG(INT,axis) SIG_ARG(INT32,popart::ScatterReduction,reduction) ) // NOLINT
OP_DECL(popart_swish_v2, aiGraphcoreOpset.swish, NONE) // NOLINT
OP_DECL(popart_incrementmod_v2, aiGraphcoreOpset.incrementmod, ARG(FLOAT,increment) ARG(FLOAT,modulus) ) // NOLINT
// Ops from AiOnnxOpset11
OP_DECL(popart_argmax, aiOnnxOpset.argmax, ARG(INT,axis) ARG(INT,keepdims) ) // NOLINT
OP_DECL(popart_argmin, aiOnnxOpset.argmin, ARG(INT,axis) ARG(INT,keepdims) ) // NOLINT
@@ -117,6 +127,7 @@ OP_DECL(popart_qlinearmatmul, aiOnnxOpset.qlinearmatmul, NONE) // NOLINT
OP_DECL(popart_quantizelinear, aiOnnxOpset.quantizelinear, NONE) // NOLINT
OP_DECL(popart_reversesequence, aiOnnxOpset.reversesequence, ARG(INT,batch_axis) ARG(INT,time_axis) ) // NOLINT
OP_DECL(popart_roialign, aiOnnxOpset.roialign, ARG(STRING,mode) ARG(INT,output_height) ARG(INT,output_width) ARG(INT,sampling_ratio) ARG(FLOAT,spatial_scale) ) // NOLINT
OP_DECL(popart_stringnormalizer, aiOnnxOpset.stringnormalizer, ARG(STRING,case_change_action) ARG(INT,is_case_sensitive) OPT_ARG(STRING,locale) ARG(STRING_VEC,stopwords) ) // NOLINT
OP_DECL(popart_thresholdedrelu, aiOnnxOpset.thresholdedrelu, ARG(FLOAT,alpha) ) // NOLINT
OP_DECL(popart_upsample, aiOnnxOpset.upsample, ARG(STRING,mode) ) // NOLINT
// Ops from AiOnnxOpset9
@@ -138,6 +149,7 @@ OP_DECL(popart_prelu, aiOnnxOpset.prelu, NONE) // NOLINT
OP_DECL(popart_shrink, aiOnnxOpset.shrink, ARG(FLOAT,bias) ARG(FLOAT,lambd) ) // NOLINT
OP_DECL(popart_sign, aiOnnxOpset.sign, NONE) // NOLINT
OP_DECL(popart_sinh, aiOnnxOpset.sinh, NONE) // NOLINT
OP_DECL(popart_tfidfvectorizer, aiOnnxOpset.tfidfvectorizer, ARG(INT,max_gram_length) ARG(INT,max_skip_count) ARG(INT,min_gram_length) ARG(STRING,mode) ARG(INT_VEC,ngram_counts) ARG(INT_VEC,ngram_indexes) ARG(INT_VEC,pool_int64s) ARG(STRING_VEC,pool_strings) ARG(FLOAT_VEC,weights) ) // NOLINT
OP_DECL(popart_where, aiOnnxOpset.where, NONE) // NOLINT
// Ops from AiOnnxOpset8
OP_DECL(popart_expand, aiOnnxOpset.expand, NONE) // NOLINT
@@ -153,10 +165,13 @@ OP_DECL(popart_asin, aiOnnxOpset.asin, NONE) // NOLINT
OP_DECL(popart_atan, aiOnnxOpset.atan, NONE) // NOLINT
OP_DECL(popart_cos, aiOnnxOpset.cos, NONE) // NOLINT
OP_DECL(popart_div, aiOnnxOpset.div, NONE) // NOLINT
OP_DECL(popart_gru, aiOnnxOpset.gru, ARG(INT,num_outputs) ARG(FLOAT_VEC,activation_alpha) ARG(FLOAT_VEC,activation_beta) ARG(STRING_VEC,activations) OPT_ARG(FLOAT,clip) ARG(STRING,direction) OPT_ARG(INT,hidden_size) ARG(INT,linear_before_reset) ) // NOLINT
OP_DECL(popart_lstm, aiOnnxOpset.lstm, ARG(INT,num_outputs) ARG(FLOAT_VEC,activation_alpha) ARG(FLOAT_VEC,activation_beta) ARG(STRING_VEC,activations) OPT_ARG(FLOAT,clip) ARG(STRING,direction) OPT_ARG(INT,hidden_size) ARG(INT,input_forget) ) // NOLINT
OP_DECL(popart_mul, aiOnnxOpset.mul, NONE) // NOLINT
OP_DECL(popart_multinomial, aiOnnxOpset.multinomial, ARG(INT,dtype) ARG(INT,sample_size) OPT_ARG(FLOAT,seed) ) // NOLINT
OP_DECL(popart_logical_or, aiOnnxOpset.logical_or, NONE) // NOLINT
OP_DECL(popart_pow, aiOnnxOpset.pow, NONE) // NOLINT
OP_DECL(popart_rnn, aiOnnxOpset.rnn, ARG(INT,num_outputs) ARG(FLOAT_VEC,activation_alpha) ARG(FLOAT_VEC,activation_beta) ARG(STRING_VEC,activations) OPT_ARG(FLOAT,clip) ARG(STRING,direction) OPT_ARG(INT,hidden_size) ) // NOLINT
OP_DECL(popart_sin, aiOnnxOpset.sin, NONE) // NOLINT
OP_DECL(popart_sub, aiOnnxOpset.sub, NONE) // NOLINT
OP_DECL(popart_tan, aiOnnxOpset.tan, NONE) // NOLINT
......
@@ -16,7 +16,6 @@
#pragma once
OP_DECL(popart_nllloss_v2, aiGraphcoreOpset.nllloss, SIG_ARG(INT32,popart::ReductionType,reduction) OPT_ARG(INT32,ignoreIndex) ARG(BOOL,inputIsLogProbability) ) // NOLINT
OP_DECL(popart_identity_loss, aiGraphcoreOpset.identityloss, SIG_ARG(INT32,popart::ReductionType,reduction) ) // NOLINT
// clang-format on
@@ -378,6 +378,11 @@ void BindTensor(pybind11::module &m) { // NOLINT
py::arg("tensor"),
py::arg("place"),
py::arg("batch_size") = -1)
.def("_copy_from",
&TensorCopyFrom<paddle::platform::IPUPlace>,
py::arg("tensor"),
py::arg("place"),
py::arg("batch_size") = -1)
.def("_copy_from",
&TensorCopyFrom<paddle::platform::Place>,
py::arg("tensor"),
......
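pybind11 treats repeated .def calls under the same name as overloads tried in registration order, which is how the added IPUPlace instantiation of TensorCopyFrom becomes reachable from Python. A reduced sketch of that overload pattern (the place types and the free-function binding are stand-ins for the real Tensor method):

```cpp
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Stand-ins for paddle::platform places; only the overload pattern is
// the point here.
struct CPUPlace {};
struct IPUPlace {};

template <typename PlaceT>
void TensorCopyFrom(int tensor, const PlaceT &place, int batch_size) {
  // real code copies tensor data to the given place
}

PYBIND11_MODULE(demo, m) {
  py::class_<CPUPlace>(m, "CPUPlace").def(py::init<>());
  py::class_<IPUPlace>(m, "IPUPlace").def(py::init<>());
  // Repeated .def with the same name registers overloads; pybind11
  // tries them in registration order until the argument types match.
  m.def("_copy_from", &TensorCopyFrom<CPUPlace>,
        py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1);
  m.def("_copy_from", &TensorCopyFrom<IPUPlace>,
        py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1);
}
```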
@@ -1486,7 +1486,11 @@ class Executor(object):
            # NOTE(dev): `set` always call TensorCopySync that is a
            # blocking behavior. So we use `_copy_from` to replace it.
            cpu_tensor = _as_lodtensor(data, core.CPUPlace())
-           tensor._copy_from(cpu_tensor, self.place)
+           # for ipu, tensor is allocated on cpu
+           if core.is_compiled_with_ipu():
+               tensor._copy_from(cpu_tensor, tensor._place())
+           else:
+               tensor._copy_from(cpu_tensor, self.place)
        return new_exe.run(scope, list(feed.keys()), fetch_list,
                           return_numpy)
......
@@ -17,7 +17,8 @@ import unittest
import numpy as np
import paddle
import paddle.static
-from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
+from paddle.jit import to_static
+from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest, IPUD2STest
class TestBase(IPUOpTest):
@@ -106,5 +107,73 @@ class TestCase2(TestBase):
}
+class SimpleLayer(paddle.nn.Layer):
+
+    def __init__(self):
+        super(SimpleLayer, self).__init__()
+        self.conv = paddle.nn.Conv2D(in_channels=3,
+                                     out_channels=1,
+                                     kernel_size=2,
+                                     stride=1)
+
+    @to_static()
+    def forward(self, x, target=None):
+        x = self.conv(x)
+        print(x)
+        x = paddle.fluid.layers.flatten(x, axis=1)
+        if target is not None:
+            x = paddle.fluid.layers.softmax(x)
+            loss = paddle.fluid.layers.cross_entropy(x, target)
+            loss = paddle.incubate.identity_loss(loss, 1)
+            return x, loss
+        return x
+
+
+class TestD2S(IPUD2STest):
+
+    def setUp(self):
+        self.set_data_feed()
+
+    def set_data_feed(self):
+        self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
+        self.label = paddle.randint(0, 10, shape=[8], dtype='int64')
+
+    def _test(self, use_ipu=False):
+        paddle.seed(self.SEED)
+        np.random.seed(self.SEED)
+        model = SimpleLayer()
+        optim = paddle.optimizer.Adam(learning_rate=0.01,
+                                      parameters=model.parameters())
+        if use_ipu:
+            paddle.set_device('ipu')
+            ipu_strategy = paddle.static.IpuStrategy()
+            ipu_strategy.set_graph_config(num_ipus=1,
+                                          is_training=True,
+                                          micro_batch_size=1,
+                                          enable_manual_shard=False)
+            ipu_strategy.set_optimizer(optim)
+
+        result = []
+        for _ in range(2):
+            # ipu only needs call model() to do forward/backward/grad_update
+            pred, loss = model(self.data, self.label)
+            if not use_ipu:
+                loss.backward()
+                optim.step()
+                optim.clear_grad()
+            result.append(loss)
+
+        if use_ipu:
+            ipu_strategy.release_patch()
+
+        return np.array(result)
+
+    def test_training(self):
+        ipu_loss = self._test(True).flatten()
+        cpu_loss = self._test(False).flatten()
+        self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=1e-4))


if __name__ == "__main__":
    unittest.main()