Unverified commit 2c8c8419, authored by Allen Guo, committed by GitHub

[IPU] add custom-op UTs 1/N (#44329)

* add custom-op UTs 1

* add authors
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>

* update url
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>
Parent c8e26fea
# Add custom op for Paddle on IPU
## Add custom op in Paddle
Reference: https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html
## Write custom op for PopART
Reference: https://docs.graphcore.ai/projects/popart-user-guide/en/latest/custom_ops.html
## Register custom op for Paddle on IPU
The custom op is used here via just-in-time (JIT) compilation.
### Implement the custom op
Following the two documents above, first add the implementation of the custom op.
`leaky_relu_cpu.cc` contains the definition and CPU implementation of the custom op for Paddle; it is written exactly as a custom op is added to standard Paddle. The CPU implementation is not required, but it can be used to verify the correctness of the IPU implementation.
`leaky_relu_ipu.cc` contains the definition and IPU implementation of the custom op for PopART; likewise, it is written exactly as a custom op is added to standard PopART.
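Because the same op is available on both backends, the CPU kernel doubles as a reference: run the same program once on `CPUPlace` and once on `IPUPlace` and compare the outputs, which is exactly what the unit test at the end of this page does. Below is a minimal sketch of that check with an extra NumPy reference; `check_leaky_relu_outputs` and its arguments are hypothetical helpers for illustration, not part of the files added here.
```python
import numpy as np

def leaky_relu_ref(x, alpha):
    # reference LeakyReLU: x < 0 ? alpha * x : x
    return np.where(x < 0, alpha * x, x)

def check_leaky_relu_outputs(cpu_out, ipu_out, x, alpha=0.01, atol=1e-6):
    # cpu_out / ipu_out are the outputs fetched from the CPU and IPU runs
    ref = leaky_relu_ref(x, alpha)
    np.testing.assert_allclose(cpu_out, ref, atol=atol)
    np.testing.assert_allclose(ipu_out, ref, atol=atol)
```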
### Load the custom op
After the custom op has been defined in both Paddle and PopART, use `paddle.utils.cpp_extension.load` to compile the source files and load the resulting shared library into the current process.
```python
cur_dir = os.path.dirname(os.path.realpath(__file__))
custom_ops = load(
    name="custom_jit_ops",
    sources=[
        f"{cur_dir}/leaky_relu_cpu.cc",
        f"{cur_dir}/leaky_relu_ipu.cc",
    ],
    # this flag is required when compiling leaky_relu_ipu.cc
    extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'])
```
Because the op definition in Paddle differs somewhat from the one in PopART, the custom op has to be mapped manually:
```python
# paddle_op is the custom op type in Paddle
# popart_op, domain and version are the custom op identifier in PopART
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.add_custom_op(
    paddle_op="custom_leaky_relu",
    popart_op="LeakyRelu",
    domain='custom.ops',
    version=1)
```
### Use the custom op
```python
x = paddle.static.data(
    name=self.feed_list[0],
    shape=self.feed_shape[0],
    dtype=self.feed_dtype[0])
# custom op
out = custom_ops.custom_leaky_relu(x, **self.attrs)
```
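The snippet above only builds the graph. To run it on the IPU, compile the program with `paddle.static.IpuCompiledProgram` and execute it, as the unit test at the end of this page does. A minimal sketch, assuming `ipu_strategy` is the strategy configured above and that `main_prog`, `startup_prog`, `scope`, `feed` and `feed_list` have been set up as in that test:
```python
ipu_strategy.set_graph_config(is_training=False)

exe = paddle.static.Executor(paddle.IPUPlace())
exe.run(startup_prog)

program = paddle.static.IpuCompiledProgram(
    main_prog, scope=scope,
    ipu_strategy=ipu_strategy).compile(feed_list, [out.name])
result = exe.run(program, feed=feed, fetch_list=[out.name])
```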
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#define CHECK_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
template <typename data_t>
void leaky_relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel,
float alpha) {
// x < 0.0f ? alpha * x : x
for (int i = 0; i < x_numel; ++i) {
if (x_data[i] > static_cast<data_t>(0.)) {
out_data[i] = x_data[i];
} else {
out_data[i] = static_cast<data_t>(alpha) * x_data[i];
}
}
}
template <typename data_t>
void leaky_relu_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel,
float alpha) {
// (grad * (x < 0.0f ? alpha : 1))
for (int i = 0; i < out_numel; ++i) {
    if (out_data[i] < static_cast<data_t>(0)) {
      grad_x_data[i] = grad_out_data[i] * static_cast<data_t>(alpha);
    } else {
      grad_x_data[i] = grad_out_data[i];
}
}
}
std::vector<paddle::Tensor> LeakyReluCPUForward(const paddle::Tensor& x,
float alpha) {
CHECK_INPUT(x);
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
PD_DISPATCH_FLOATING_TYPES(x.type(), "relu_cpu_forward_kernel", ([&] {
leaky_relu_cpu_forward_kernel<data_t>(
x.data<data_t>(),
out.mutable_data<data_t>(x.place()),
x.size(),
alpha);
}));
return {out};
}
std::vector<paddle::Tensor> LeakyReluCPUBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out,
float alpha) {
CHECK_INPUT(x);
CHECK_INPUT(out);
CHECK_INPUT(grad_out);
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward_kernel", ([&] {
leaky_relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
out.size(),
alpha);
}));
return {grad_x};
}
std::vector<std::vector<int64_t>> LeakyReluInferShape(
std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> LeakyReluInferDtype(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP(custom_leaky_relu)
.Inputs({"X"})
.Outputs({"Out"})
.Attrs({"alpha: float"})
.SetKernelFn(PD_KERNEL(LeakyReluCPUForward))
.SetInferShapeFn(PD_INFER_SHAPE(LeakyReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(LeakyReluInferDtype));
PD_BUILD_GRAD_OP(custom_leaky_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.Attrs({"alpha: float"})
.SetKernelFn(PD_KERNEL(LeakyReluCPUBackward));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <popart/opmanager.hpp>
#include <popart/opserialiser.hpp>
#include <popart/popx/opxmanager.hpp>
#include <popart/shapeinference.hpp>
#include <popops/ElementWise.hpp>
namespace CustomOperators {
const popart::OperatorIdentifier LeakyReluId = {"custom.ops", "LeakyRelu", 1};
} // namespace CustomOperators
namespace CustomGradOperators {
const popart::OperatorIdentifier LeakyReluGradId = {
"custom.ops", "LeakyReluGrad", 1};
} // namespace CustomGradOperators
class LeakyReluOp;
class LeakyReluOpx;
class LeakyReluGradOpx;
class LeakyReluGradOp : public popart::Op {
public:
explicit LeakyReluGradOp(const LeakyReluOp &fwdOp);
std::unique_ptr<popart::Op> clone() const final {
return std::make_unique<LeakyReluGradOp>(*this);
}
void setup() final { outInfo(0) = inInfo(0); };
const std::vector<popart::GradInOutMapper> &gradInputInfo() const;
// The Grad Op has 1 output, which is the gradient of the only input
const std::map<int, int> &gradOutToNonGradIn() const;
bool requiresRandomSeed() const override { return false; }
// an estimate of how valuable sub-graph matching will be
float getSubgraphValue() const final { return getHighSubgraphValue(); }
float getAlpha() const { return alpha; }
// Implementation defined below
void appendAttributes(popart::OpSerialiserBase &os) const override;
// Implementation defined below
void appendOutlineAttributes(popart::OpSerialiserBase &os) const override;
private:
float alpha;
};
class LeakyReluOp : public popart::Op {
public:
LeakyReluOp(const popart::OperatorIdentifier &_opid,
float _alpha,
const popart::Op::Settings &settings_)
: popart::Op(_opid, settings_), alpha(_alpha) {}
std::unique_ptr<Op> clone() const final {
return std::make_unique<LeakyReluOp>(*this);
}
void setup() final { outInfo(0) = inInfo(0); }
void appendAttributes(popart::OpSerialiserBase &os) const override {
Op::appendAttributes(os);
os.appendAttribute("alpha", getAlpha());
}
void appendOutlineAttributes(popart::OpSerialiserBase &os) const override {
Op::appendOutlineAttributes(os);
os.appendAttribute("alpha", getAlpha());
}
std::vector<std::unique_ptr<popart::Op>> getGradOps() {
std::vector<std::unique_ptr<Op>> upops;
upops.emplace_back(new LeakyReluGradOp(*this));
return upops;
}
float getSubgraphValue() const final { return getHighSubgraphValue(); }
bool requiresRandomSeed() const override { return false; }
// Attributes
float getAlpha() const { return alpha; }
private:
float alpha;
};
namespace {
using popart::DataType;
using popart::OpDefinition;
static OpDefinition::DataTypes T = {DataType::FLOAT16, DataType::FLOAT};
static OpDefinition leakyReluOpDef({OpDefinition::Inputs({{"input", T}}),
OpDefinition::Outputs({{"output", T}}),
OpDefinition::Attributes({{"alpha",
{"*"}}})});
static popart::OpCreator<LeakyReluOp> leakyReluOpCreator(
popart::OpDefinitions({{CustomOperators::LeakyReluId, leakyReluOpDef}}),
[](const popart::OpCreatorInfo &info) {
// default alpha is 10**(-2)
float alpha = info.attributes.getAttribute<popart::Attributes::Float>(
"alpha", 1e-2f);
return std::make_unique<LeakyReluOp>(info.opid, alpha, info.settings);
},
true);
} // namespace
static popart::RegisterShapeInferenceFunction leakyReluShapeInfer(
CustomOperators::LeakyReluId,
    [](popart::ShapeInferenceContext &ctx  // NOLINT
) { ctx.outInfo(0) = ctx.inInfo(0); });
namespace pe = popops::expr;
class LeakyReluOpx : public popart::popx::Opx {
public:
LeakyReluOpx(popart::Op *op, popart::popx::Devicex *devicex)
: popart::popx::Opx(op, devicex) {
verifyOp<LeakyReluOp>(op, {CustomOperators::LeakyReluId});
}
void grow(poplar::program::Sequence &prog) const final { // NOLINT
popart::logging::ir::trace("start Growing LeakyReluOpx");
auto op = getOp<LeakyReluOp>();
poplar::Tensor input = getInTensor(0);
float alpha = op.getAlpha();
// x < 0.0f ? alpha * x : x
auto expression = pe::Select(pe::Mul(pe::Const(alpha), pe::_1),
pe::_1,
pe::Lt(pe::_1, pe::Const(0.0f)));
popops::mapInPlace(graph(),
expression,
{input},
prog,
debugContext("LeakyRelu"),
poplar::OptionFlags());
setOutTensor(0, input);
}
};
class LeakyReluGradOpx : public popart::popx::Opx {
public:
LeakyReluGradOpx(popart::Op *op, popart::popx::Devicex *devicex)
: popart::popx::Opx(op, devicex) {
verifyOp<LeakyReluGradOp>(op, {CustomGradOperators::LeakyReluGradId});
}
void grow(poplar::program::Sequence &prog) const final { // NOLINT
auto op = getOp<LeakyReluGradOp>();
poplar::Tensor grad = getInTensor(0);
poplar::Tensor input = getInTensor(1);
float alpha = op.getAlpha();
// (grad * (x < 0.0f ? alpha : 1))
pe::Mul expression = pe::Mul(
pe::Select(
pe::Const(alpha), pe::Const(1.0f), pe::Lt(pe::_2, pe::Const(0.0f))),
pe::_1);
auto output = popops::map(graph(),
expression,
{grad, input},
prog,
debugContext("LeakyReluGrad"),
poplar::OptionFlags());
setOutTensor(0, output);
}
};
LeakyReluGradOp::LeakyReluGradOp(const LeakyReluOp &fwdOp)
: popart::Op(CustomGradOperators::LeakyReluGradId, fwdOp.settings),
alpha(fwdOp.getAlpha()) {}
const std::vector<popart::GradInOutMapper> &LeakyReluGradOp::gradInputInfo()
const {
static const std::vector<popart::GradInOutMapper> inInfo = {
{0, 0, popart::GradOpInType::GradOut}, {1, 0, popart::GradOpInType::In}};
return inInfo;
}
// The Grad Op has 1 output, which is the gradient of the only input
const std::map<int, int> &LeakyReluGradOp::gradOutToNonGradIn() const {
static const std::map<int, int> outInfo = {{0, 0}};
return outInfo;
}
void LeakyReluGradOp::appendAttributes(popart::OpSerialiserBase &os) const {
Op::appendAttributes(os);
os.appendAttribute("alpha", getAlpha());
}
void LeakyReluGradOp::appendOutlineAttributes(
popart::OpSerialiserBase &os) const {
Op::appendOutlineAttributes(os);
os.appendAttribute("alpha", getAlpha());
}
static popart::popx::OpxCreator<LeakyReluOpx> LeakyReluOpxCreator(
{CustomOperators::LeakyReluId});
static popart::popx::OpxCreator<LeakyReluGradOpx> LeakyReluGradOpxCreator(
{CustomGradOperators::LeakyReluGradId});
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import sys
import numpy as np
import paddle
import paddle.optimizer
import paddle.static
from paddle.utils.cpp_extension import load
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from op_test_ipu import IPUOpTest, np_dtype_to_fluid_str
def load_custom_ops():
    # load custom ops
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    custom_ops = load(name="custom_jit_ops",
                      sources=[
                          f"{cur_dir}/leaky_relu_cpu.cc",
                          f"{cur_dir}/leaky_relu_ipu.cc",
                      ],
                      extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'])
    return custom_ops
class TestBase(IPUOpTest):

    def setUp(self):
        self.set_atol()
        self.set_training()
        self.set_feed()
        self.set_feed_attr()
        self.set_attrs()

    def set_feed(self):
        self.feed = {
            "x": np.random.uniform(low=-2, high=2,
                                   size=[3, 5]).astype('float32'),
        }

    def set_feed_attr(self):
        self.feed_shape = [x.shape for x in self.feed.values()]
        self.feed_list = list(self.feed.keys())
        self.feed_dtype = [
            np_dtype_to_fluid_str(x.dtype) for x in self.feed.values()
        ]

    def set_attrs(self):
        self.attrs = {'alpha': 0.1}

    def _test_base(self, run_ipu=True):
        scope = paddle.static.Scope()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        SEED = self.SEED
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        custom_ops = load_custom_ops()

        with paddle.static.scope_guard(scope):
            with paddle.static.program_guard(main_prog, startup_prog):
                x = paddle.static.data(name=self.feed_list[0],
                                       shape=self.feed_shape[0],
                                       dtype=self.feed_dtype[0])
                # custom op
                out = custom_ops.custom_leaky_relu(x, **self.attrs)
                fetch_list = [out.name]

            if run_ipu:
                place = paddle.IPUPlace()
            else:
                place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)
            exe.run(startup_prog)

            if run_ipu:
                feed_list = self.feed_list
                ipu_strategy = paddle.static.IpuStrategy()
                ipu_strategy.set_graph_config(is_training=False)
                # add name mapping between the paddle custom op and the popart custom op
                # `paddle_op` is defined in leaky_relu_cpu.cc
                # `popart_op`, `domain` and `version` are defined in leaky_relu_ipu.cc
                ipu_strategy.add_custom_op(paddle_op="custom_leaky_relu",
                                           popart_op="LeakyRelu",
                                           domain='custom.ops',
                                           version=1)
                program = paddle.static.IpuCompiledProgram(
                    main_prog, scope=scope,
                    ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
            else:
                program = main_prog

            result = exe.run(program, feed=self.feed, fetch_list=fetch_list)
            return result[0]

    def test_base(self):
        res0 = self._test_base(False)
        res1 = self._test_base(True)

        self.assertTrue(
            np.allclose(res0.flatten(), res1.flatten(), atol=self.atol))
        self.assertTrue(res0.shape == res1.shape)


if __name__ == "__main__":
    unittest.main()