From 2c8c8419095c6724691bc7a0be6002c722611eb4 Mon Sep 17 00:00:00 2001
From: Allen Guo
Date: Fri, 15 Jul 2022 17:05:21 +0800
Subject: [PATCH] [IPU] add custom-op UTs 1/N (#44329)

* add custom-op UTs 1

* add authors

Co-authored-by: Allen Guo
Co-authored-by: Zhixin Yao
Co-authored-by: Zhaorui Chen

* update url

Co-authored-by: Zhixin Yao
Co-authored-by: Zhaorui Chen
---
 .../tests/unittests/ipu/custom_ops/README.md  |  71 ++++++
 .../ipu/custom_ops/leaky_relu_cpu.cc          | 111 +++++++++
 .../ipu/custom_ops/leaky_relu_ipu.cc          | 229 ++++++++++++++++++
 .../custom_ops/test_custom_leaky_relu_ipu.py  | 124 ++++++++++
 4 files changed, 535 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/custom_ops/README.md
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_cpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_ipu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ipu/custom_ops/test_custom_leaky_relu_ipu.py

diff --git a/python/paddle/fluid/tests/unittests/ipu/custom_ops/README.md b/python/paddle/fluid/tests/unittests/ipu/custom_ops/README.md
new file mode 100644
index 0000000000..efac2a764a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/custom_ops/README.md
@@ -0,0 +1,71 @@
+# Add custom op for Paddle on IPU
+
+## Add custom op in Paddle
+
+Reference:
+
+https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html
+
+## Write custom op for PopART
+
+Reference:
+
+https://docs.graphcore.ai/projects/popart-user-guide/en/latest/custom_ops.html
+
+## Register custom op for Paddle on IPU
+
+The custom op is built and used via just-in-time (JIT) compilation.
+
+### Implement the custom op
+
+Following the two documents above, first add the implementation of the custom op.
+
+`leaky_relu_cpu.cc` contains the definition and the CPU implementation of the custom op in Paddle; it is written exactly the way a custom op is normally added to Paddle. The CPU implementation is optional, but it can be used to verify the correctness of the IPU implementation.
+
+`leaky_relu_ipu.cc` contains the definition and the IPU implementation of the custom op in PopART; likewise, it is written exactly the way a custom op is normally added to PopART.
+
+### Load the custom op
+
+After the custom op has been defined in both Paddle and PopART, use `paddle.utils.cpp_extension.load` to compile the source files and load the resulting shared library into the current process.
+
+```python
+
+cur_dir = os.path.dirname(os.path.realpath(__file__))
+custom_ops = load(
+    name="custom_jit_ops",
+    sources=[
+        f"{cur_dir}/leaky_relu_cpu.cc",
+        f"{cur_dir}/leaky_relu_ipu.cc",
+    ],
+    # this flag is required when compiling leaky_relu_ipu.cc
+    extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'])
+
+```
+
+Because the op definition in Paddle differs from the one in PopART, the custom op has to be mapped manually. `paddle_op` is the op type registered in `leaky_relu_cpu.cc`; `popart_op`, `domain` and `version` identify the op registered in `leaky_relu_ipu.cc`.
+
+```python
+
+# paddle_op is the custom op type in Paddle
+# popart_op, domain and version form the custom op identifier in PopART
+ipu_strategy = paddle.static.IpuStrategy()
+ipu_strategy.add_custom_op(
+    paddle_op="custom_leaky_relu",
+    popart_op="LeakyRelu",
+    domain='custom.ops',
+    version=1)
+
+```
+
+### Use the custom op
+
+```python
+
+x = paddle.static.data(
+    name=self.feed_list[0],
+    shape=self.feed_shape[0],
+    dtype=self.feed_dtype[0])
+# custom op
+out = custom_ops.custom_leaky_relu(x, **self.attrs)
+
+```
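+
+### Compile and run on the IPU
+
+The snippet below is a minimal end-to-end sketch of how the pieces above fit together; it mirrors `test_custom_leaky_relu_ipu.py` in this directory. The tensor name, shape and `alpha` value are only example values, and `custom_ops` is assumed to be the module returned by `load` above.
+
+```python
+
+import numpy as np
+import paddle
+import paddle.static
+
+paddle.enable_static()
+
+main_prog = paddle.static.Program()
+startup_prog = paddle.static.Program()
+with paddle.static.program_guard(main_prog, startup_prog):
+    x = paddle.static.data(name='x', shape=[3, 5], dtype='float32')
+    # custom_ops is the module returned by load(...) above
+    out = custom_ops.custom_leaky_relu(x, alpha=0.1)
+
+exe = paddle.static.Executor(paddle.IPUPlace())
+exe.run(startup_prog)
+
+ipu_strategy = paddle.static.IpuStrategy()
+ipu_strategy.set_graph_config(is_training=False)
+# map the Paddle op onto the PopART op, as described above
+ipu_strategy.add_custom_op(
+    paddle_op="custom_leaky_relu",
+    popart_op="LeakyRelu",
+    domain='custom.ops',
+    version=1)
+
+program = paddle.static.IpuCompiledProgram(
+    main_prog, ipu_strategy=ipu_strategy).compile([x.name], [out.name])
+
+feed = {'x': np.random.uniform(size=[3, 5]).astype('float32')}
+res = exe.run(program, feed=feed, fetch_list=[out.name])
+
+```
+
+Running the same program on `paddle.CPUPlace()` without the IPU strategy exercises the CPU kernel from `leaky_relu_cpu.cc`, which is how the unit test in this directory verifies the IPU results.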
diff --git a/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_cpu.cc b/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_cpu.cc
new file mode 100644
index 0000000000..d118aa4380
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_cpu.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/extension.h"
+
+#define CHECK_INPUT(x) \
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
+
+template <typename data_t>
+void leaky_relu_cpu_forward_kernel(const data_t* x_data,
+                                   data_t* out_data,
+                                   int64_t x_numel,
+                                   float alpha) {
+  // x < 0.0f ? alpha * x : x
+  for (int i = 0; i < x_numel; ++i) {
+    if (x_data[i] > static_cast<data_t>(0.)) {
+      out_data[i] = x_data[i];
+    } else {
+      out_data[i] = static_cast<data_t>(alpha) * x_data[i];
+    }
+  }
+}
+
+template <typename data_t>
+void leaky_relu_cpu_backward_kernel(const data_t* grad_out_data,
+                                    const data_t* out_data,
+                                    data_t* grad_x_data,
+                                    int64_t out_numel,
+                                    float alpha) {
+  // (grad * (x < 0.0f ? alpha : 1))
+  for (int i = 0; i < out_numel; ++i) {
+    if (out_data[i] < static_cast<data_t>(0)) {
+      grad_x_data[i] = grad_out_data[i] * static_cast<data_t>(alpha);
+    } else {
+      grad_x_data[i] = grad_out_data[i] * static_cast<data_t>(1.);
+    }
+  }
+}
+
+std::vector<paddle::Tensor> LeakyReluCPUForward(const paddle::Tensor& x,
+                                                float alpha) {
+  CHECK_INPUT(x);
+
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(x.type(), "relu_cpu_forward_kernel", ([&] {
+                               leaky_relu_cpu_forward_kernel<data_t>(
+                                   x.data<data_t>(),
+                                   out.mutable_data<data_t>(x.place()),
+                                   x.size(),
+                                   alpha);
+                             }));
+
+  return {out};
+}
+
+std::vector<paddle::Tensor> LeakyReluCPUBackward(const paddle::Tensor& x,
+                                                 const paddle::Tensor& out,
+                                                 const paddle::Tensor& grad_out,
+                                                 float alpha) {
+  CHECK_INPUT(x);
+  CHECK_INPUT(out);
+  CHECK_INPUT(grad_out);
+
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward_kernel", ([&] {
+                               leaky_relu_cpu_backward_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   out.data<data_t>(),
+                                   grad_x.mutable_data<data_t>(x.place()),
+                                   out.size(),
+                                   alpha);
+                             }));
+
+  return {grad_x};
+}
+
+std::vector<std::vector<int64_t>> LeakyReluInferShape(
+    std::vector<int64_t> x_shape) {
+  return {x_shape};
+}
+
+std::vector<paddle::DataType> LeakyReluInferDtype(paddle::DataType x_dtype) {
+  return {x_dtype};
+}
+
+PD_BUILD_OP(custom_leaky_relu)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .Attrs({"alpha: float"})
+    .SetKernelFn(PD_KERNEL(LeakyReluCPUForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(LeakyReluInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(LeakyReluInferDtype));
+
+PD_BUILD_GRAD_OP(custom_leaky_relu)
+    .Inputs({"X", "Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .Attrs({"alpha: float"})
+    .SetKernelFn(PD_KERNEL(LeakyReluCPUBackward));
diff --git a/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_ipu.cc b/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_ipu.cc
new file mode 100644
index 0000000000..1fea75b3b5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/custom_ops/leaky_relu_ipu.cc
@@ -0,0 +1,229 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <popart/op.hpp>
+#include <popart/opmanager.hpp>
+#include <popart/popx/opx.hpp>
+#include <popart/popx/opxmanager.hpp>
+
+#include <popops/ElementWise.hpp>
+
+namespace CustomOperators {
+const popart::OperatorIdentifier LeakyReluId = {"custom.ops", "LeakyRelu", 1};
+}  // namespace CustomOperators
+namespace CustomGradOperators {
+const popart::OperatorIdentifier LeakyReluGradId = {
+    "custom.ops", "LeakyReluGrad", 1};
+}  // namespace CustomGradOperators
+
+class LeakyReluOp;
+class LeakyReluOpx;
+class LeakyReluGradOpx;
+
+class LeakyReluGradOp : public popart::Op {
+ public:
+  explicit LeakyReluGradOp(const LeakyReluOp &fwdOp);
+
+  std::unique_ptr<popart::Op> clone() const final {
+    return std::make_unique<LeakyReluGradOp>(*this);
+  }
+  void setup() final { outInfo(0) = inInfo(0); };
+
+  const std::vector<popart::GradInOutMapper> &gradInputInfo() const;
+
+  // The Grad Op has 1 output, which is the gradient of the only input
+  const std::map<int, int> &gradOutToNonGradIn() const;
+
+  bool requiresRandomSeed() const override { return false; }
+
+  // an estimate of how valuable sub-graph matching will be
+  float getSubgraphValue() const final { return getHighSubgraphValue(); }
+
+  float getAlpha() const { return alpha; }
+
+  // Implementation defined below
+  void appendAttributes(popart::OpSerialiserBase &os) const override;
+
+  // Implementation defined below
+  void appendOutlineAttributes(popart::OpSerialiserBase &os) const override;
+
+ private:
+  float alpha;
+};
+
+class LeakyReluOp : public popart::Op {
+ public:
+  LeakyReluOp(const popart::OperatorIdentifier &_opid,
+              float _alpha,
+              const popart::Op::Settings &settings_)
+      : popart::Op(_opid, settings_), alpha(_alpha) {}
+
+  std::unique_ptr<popart::Op> clone() const final {
+    return std::make_unique<LeakyReluOp>(*this);
+  }
+
+  void setup() final { outInfo(0) = inInfo(0); }
+
+  void appendAttributes(popart::OpSerialiserBase &os) const override {
+    Op::appendAttributes(os);
+    os.appendAttribute("alpha", getAlpha());
+  }
+
+  void appendOutlineAttributes(popart::OpSerialiserBase &os) const override {
+    Op::appendOutlineAttributes(os);
+    os.appendAttribute("alpha", getAlpha());
+  }
+
+  std::vector<std::unique_ptr<popart::Op>> getGradOps() {
+    std::vector<std::unique_ptr<popart::Op>> upops;
+    upops.emplace_back(new LeakyReluGradOp(*this));
+    return upops;
+  }
+
+  float getSubgraphValue() const final { return getHighSubgraphValue(); }
+
+  bool requiresRandomSeed() const override { return false; }
+
+  // Attributes
+  float getAlpha() const { return alpha; }
+
+ private:
+  float alpha;
+};
+
+namespace {
+using popart::DataType;
+using popart::OpDefinition;
+
+static OpDefinition::DataTypes T = {DataType::FLOAT16, DataType::FLOAT};
+
+static OpDefinition leakyReluOpDef({OpDefinition::Inputs({{"input", T}}),
+                                    OpDefinition::Outputs({{"output", T}}),
+                                    OpDefinition::Attributes({{"alpha",
+                                                               {"*"}}})});
+
+static popart::OpCreator<LeakyReluOp> leakyReluOpCreator(
+    popart::OpDefinitions({{CustomOperators::LeakyReluId, leakyReluOpDef}}),
+    [](const popart::OpCreatorInfo &info) {
+      // default alpha is 10**(-2)
+      float alpha =
+          info.attributes.getAttribute<popart::Attributes::Float>("alpha",
+                                                                  1e-2f);
+      return std::make_unique<LeakyReluOp>(info.opid, alpha, info.settings);
+    },
+    true);
+}  // namespace
+
+static popart::RegisterShapeInferenceFunction leakyReluShapeInfer(
+    CustomOperators::LeakyReluId,
+    [](popart::ShapeInferenceContext &ctx  // NOLINT
+    ) { ctx.outInfo(0) = ctx.inInfo(0); });
+
+namespace pe = popops::expr;
+
+class LeakyReluOpx : public popart::popx::Opx {
+ public:
+  LeakyReluOpx(popart::Op *op, popart::popx::Devicex *devicex)
+      : popart::popx::Opx(op, devicex) {
+    verifyOp<LeakyReluOp>(op, {CustomOperators::LeakyReluId});
+  }
+
+  void grow(poplar::program::Sequence &prog) const final {  // NOLINT
+    popart::logging::ir::trace("start Growing LeakyReluOpx");
+
+    auto op = getOp<LeakyReluOp>();
+
+    poplar::Tensor input = getInTensor(0);
+
+    float alpha = op.getAlpha();
+
+    // x < 0.0f ? alpha * x : x
+    auto expression = pe::Select(pe::Mul(pe::Const(alpha), pe::_1),
+                                 pe::_1,
+                                 pe::Lt(pe::_1, pe::Const(0.0f)));
+
+    popops::mapInPlace(graph(),
+                       expression,
+                       {input},
+                       prog,
+                       debugContext("LeakyRelu"),
+                       poplar::OptionFlags());
+
+    setOutTensor(0, input);
+  }
+};
+
+class LeakyReluGradOpx : public popart::popx::Opx {
+ public:
+  LeakyReluGradOpx(popart::Op *op, popart::popx::Devicex *devicex)
+      : popart::popx::Opx(op, devicex) {
+    verifyOp<LeakyReluGradOp>(op, {CustomGradOperators::LeakyReluGradId});
+  }
+
+  void grow(poplar::program::Sequence &prog) const final {  // NOLINT
+    auto op = getOp<LeakyReluGradOp>();
+
+    poplar::Tensor grad = getInTensor(0);
+    poplar::Tensor input = getInTensor(1);
+
+    float alpha = op.getAlpha();
+
+    // (grad * (x < 0.0f ? alpha : 1))
+    pe::Mul expression = pe::Mul(
+        pe::Select(
+            pe::Const(alpha), pe::Const(1.0f), pe::Lt(pe::_2, pe::Const(0.0f))),
+        pe::_1);
+
+    auto output = popops::map(graph(),
+                              expression,
+                              {grad, input},
+                              prog,
+                              debugContext("LeakyReluGrad"),
+                              poplar::OptionFlags());
+
+    setOutTensor(0, output);
+  }
+};
+
+LeakyReluGradOp::LeakyReluGradOp(const LeakyReluOp &fwdOp)
+    : popart::Op(CustomGradOperators::LeakyReluGradId, fwdOp.settings),
+      alpha(fwdOp.getAlpha()) {}
+
+const std::vector<popart::GradInOutMapper> &LeakyReluGradOp::gradInputInfo()
+    const {
+  static const std::vector<popart::GradInOutMapper> inInfo = {
+      {0, 0, popart::GradOpInType::GradOut}, {1, 0, popart::GradOpInType::In}};
+  return inInfo;
+}
+
+// The Grad Op has 1 output, which is the gradient of the only input
+const std::map<int, int> &LeakyReluGradOp::gradOutToNonGradIn() const {
+  static const std::map<int, int> outInfo = {{0, 0}};
+  return outInfo;
+}
+
+void LeakyReluGradOp::appendAttributes(popart::OpSerialiserBase &os) const {
+  Op::appendAttributes(os);
+  os.appendAttribute("alpha", getAlpha());
+}
+
+void LeakyReluGradOp::appendOutlineAttributes(
+    popart::OpSerialiserBase &os) const {
+  Op::appendOutlineAttributes(os);
+  os.appendAttribute("alpha", getAlpha());
+}
+
+static popart::popx::OpxCreator<LeakyReluOpx> LeakyReluOpxCreator(
+    {CustomOperators::LeakyReluId});
+static popart::popx::OpxCreator<LeakyReluGradOpx> LeakyReluGradOpxCreator(
+    {CustomGradOperators::LeakyReluGradId});
diff --git a/python/paddle/fluid/tests/unittests/ipu/custom_ops/test_custom_leaky_relu_ipu.py b/python/paddle/fluid/tests/unittests/ipu/custom_ops/test_custom_leaky_relu_ipu.py
new file mode 100644
index 0000000000..fb3fcbf5fe
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/custom_ops/test_custom_leaky_relu_ipu.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+import sys
+
+import numpy as np
+import paddle
+import paddle.optimizer
+import paddle.static
+from paddle.utils.cpp_extension import load
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from op_test_ipu import IPUOpTest, np_dtype_to_fluid_str
+
+
+def load_custom_ops():
+    # load custom ops
+    cur_dir = os.path.dirname(os.path.realpath(__file__))
+    custom_ops = load(name="custom_jit_ops",
+                      sources=[
+                          f"{cur_dir}/leaky_relu_cpu.cc",
+                          f"{cur_dir}/leaky_relu_ipu.cc",
+                      ],
+                      extra_cxx_cflags=['-DONNX_NAMESPACE=onnx'])
+    return custom_ops
+
+
+class TestBase(IPUOpTest):
+
+    def setUp(self):
+        self.set_atol()
+        self.set_training()
+        self.set_feed()
+        self.set_feed_attr()
+        self.set_attrs()
+
+    def set_feed(self):
+        self.feed = {
+            "x":
+            np.random.uniform(low=-2, high=2, size=[3, 5]).astype('float32'),
+        }
+
+    def set_feed_attr(self):
+        self.feed_shape = [x.shape for x in self.feed.values()]
+        self.feed_list = list(self.feed.keys())
+        self.feed_dtype = [
+            np_dtype_to_fluid_str(x.dtype) for x in self.feed.values()
+        ]
+
+    def set_attrs(self):
+        self.attrs = {'alpha': 0.1}
+
+    def _test_base(self, run_ipu=True):
+        scope = paddle.static.Scope()
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        SEED = self.SEED
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        custom_ops = load_custom_ops()
+
+        with paddle.static.scope_guard(scope):
+            with paddle.static.program_guard(main_prog, startup_prog):
+                x = paddle.static.data(name=self.feed_list[0],
+                                       shape=self.feed_shape[0],
+                                       dtype=self.feed_dtype[0])
+                # custom op
+                out = custom_ops.custom_leaky_relu(x, **self.attrs)
+                fetch_list = [out.name]
+
+            if run_ipu:
+                place = paddle.IPUPlace()
+            else:
+                place = paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
+            exe.run(startup_prog)
+
+            if run_ipu:
+                feed_list = self.feed_list
+                ipu_strategy = paddle.static.IpuStrategy()
+                ipu_strategy.set_graph_config(is_training=False)
+
+                # add the name mapping between the Paddle custom op and the PopART custom op
+                # `paddle_op` is defined in leaky_relu_cpu.cc
+                # `popart_op`, `domain` and `version` are defined in leaky_relu_ipu.cc
+                ipu_strategy.add_custom_op(paddle_op="custom_leaky_relu",
+                                           popart_op="LeakyRelu",
+                                           domain='custom.ops',
+                                           version=1)
+
+                program = paddle.static.IpuCompiledProgram(
+                    main_prog, scope=scope,
+                    ipu_strategy=ipu_strategy).compile(feed_list, fetch_list)
+            else:
+                program = main_prog
+
+            result = exe.run(program, feed=self.feed, fetch_list=fetch_list)
+            return result[0]
+
+    def test_base(self):
+        res0 = self._test_base(False)
+        res1 = self._test_base(True)
+
+        self.assertTrue(
+            np.allclose(res0.flatten(), res1.flatten(), atol=self.atol))
+
+        self.assertTrue(res0.shape == res1.shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab