未验证 提交 a237ff8e 编写于 作者: Allen Guo 提交者: GitHub

[IPU] support depthwise_conv2d ops (#45234)

* support depthwise_conv2d ops
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>

* fix duplicate name
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>
上级 f49f3b4f
......@@ -138,6 +138,10 @@ Node *brelu_handler(Graph *graph, Node *node) {
Node *gelu_handler(Graph *graph, Node *node) {
auto *op = node->Op();
// In case of the Op has no `approximate` attr.
if (!op->HasAttr("approximate")) {
return activation_op_handler(graph, node, "popart_gelu_v2");
}
auto approximate_ = PADDLE_GET_CONST(bool, op->GetAttr("approximate"));
if (approximate_) {
return activation_op_handler(graph, node, "popart_gelu_v2");
......
......@@ -824,6 +824,14 @@ Node *pad_handler(Graph *graph, Node *node) {
{{"mode", mode}});
}
// Lowers Paddle's depthwise_conv2d op by delegating to the generic conv2d
// handler. Presumably the depthwise grouping is carried by the op's `groups`
// attribute, which conv2d_handler forwards — confirm in conv2d_handler.
Node *depthwise_conv2d_handler(Graph *graph, Node *node) {
  return conv2d_handler(graph, node);
}
// Lowers Paddle's depthwise_conv2d_transpose op via the generic
// conv2d_transpose handler. As with depthwise_conv2d, the depthwise behavior
// presumably comes from the `groups` attribute handled there — TODO confirm.
Node *depthwise_conv2d_transpose_handler(Graph *graph, Node *node) {
  return conv2d_transpose_handler(graph, node);
}
} // namespace
} // namespace ipu
} // namespace platform
......@@ -846,3 +854,6 @@ REGISTER_HANDLER(linear_interp_v2, linear_interp_v2_handler);
REGISTER_HANDLER(trilinear_interp_v2, trilinear_interp_v2_handler);
REGISTER_HANDLER(data_norm, data_norm_handler);
REGISTER_HANDLER(pad3d, pad_handler);
REGISTER_HANDLER(depthwise_conv2d, depthwise_conv2d_handler);
REGISTER_HANDLER(depthwise_conv2d_transpose,
depthwise_conv2d_transpose_handler);
......@@ -150,11 +150,43 @@ class TestCase10(TestBase):
class TestCase11(TestBase):
    """Depthwise conv2d_transpose: base attributes plus groups=3."""

    def set_op_attrs(self):
        # Take the base-class attribute set, then make the transposed
        # convolution grouped (depthwise when groups == input channels).
        super().set_op_attrs()
        self.attrs.update({'groups': 3})
# depthwise_conv2d_transpose Op
class TestCase12(TestBase):
    """conv2d_transpose fed by an explicit weight tensor with groups=3."""

    def set_feed(self):
        # NCHW input [1, 3, 10, 10]; weight [3, 1, 3, 3] — one filter per
        # input channel, matching the groups=3 attribute below.
        samples = {
            'in_0': np.random.uniform(size=[1, 3, 10, 10]),
            'in_1': np.random.uniform(size=[3, 1, 3, 3]),
        }
        self.feed_fp32 = {k: v.astype(np.float32) for k, v in samples.items()}
        self.feed_fp16 = {k: v.astype(np.float16) for k, v in samples.items()}
        self.feed_shape = [x.shape for x in self.feed_fp32.values()]
        self.feed_list = list(self.feed_fp32.keys())

    def set_op_attrs(self):
        # groups == number of input channels -> depthwise transposed conv.
        self.attrs = {'groups': 3}

    @IPUOpTest.static_graph
    def build_model(self):
        in_names = self.feed_list
        x = paddle.static.data(name=in_names[0],
                               shape=self.feed_shape[0],
                               dtype='float32')
        weight = paddle.static.data(name=in_names[1],
                                    shape=self.feed_shape[1],
                                    dtype='float32')
        out = paddle.nn.functional.conv2d_transpose(x, weight, **self.attrs)
        self.fetch_list = [out.name]
if __name__ == "__main__":
    # Discover and run every TestCase* class in this module.
    unittest.main()
......@@ -133,5 +133,38 @@ class TestCase8(TestBase):
self.attrs['padding'] = [1, 2, 2, 3]
# depthwise_conv2d Op
class TestCase9(TestBase):
    """conv2d fed by an explicit weight tensor with groups=3 (depthwise)."""

    def set_feed(self):
        # NCHW input [1, 3, 10, 10]; weight [3, 1, 3, 3] — one filter per
        # input channel, matching the groups=3 attribute below.
        samples = {
            'in_0': np.random.uniform(size=[1, 3, 10, 10]),
            'in_1': np.random.uniform(size=[3, 1, 3, 3]),
        }
        self.feed_fp32 = {k: v.astype(np.float32) for k, v in samples.items()}
        self.feed_fp16 = {k: v.astype(np.float16) for k, v in samples.items()}
        self.feed_shape = [x.shape for x in self.feed_fp32.values()]
        self.feed_list = list(self.feed_fp32.keys())

    def set_op_attrs(self):
        # groups == number of input channels -> depthwise convolution.
        self.attrs = {'groups': 3}

    @IPUOpTest.static_graph
    def build_model(self):
        in_names = self.feed_list
        x = paddle.static.data(name=in_names[0],
                               shape=self.feed_shape[0],
                               dtype='float32')
        weight = paddle.static.data(name=in_names[1],
                                    shape=self.feed_shape[1],
                                    dtype='float32')
        out = paddle.nn.functional.conv2d(x, weight, **self.attrs)
        self.fetch_list = [out.name]
if __name__ == "__main__":
    # Discover and run every TestCase* class in this module.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册