Unverified commit 4c9f96c9, authored by Aurelius84, committed by GitHub

[CustomOp] Support compiling multiple ops at the same time (#30920)


* add more unittests for ABI compatibility

* add more unittests

* refine warning style

* support compiling multiple custom ops at the same time

* fix missing paddle import in unittest

* fix typo

* add more unittests

* add comment for details
Parent caf9d398
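For context, this change lets paddle.utils.cpp_extension.load compile several custom operator source files into one shared library and expose each registered operator as an attribute of the returned module, instead of requiring exactly one operator per build. A minimal usage sketch, assuming relu_op.cc/.cu and relu_op3.cc/.cu register the relu2 and relu3 operators as in the test sources below:

import numpy as np
import paddle
from paddle.utils.cpp_extension import load

# JIT-compile two operator source pairs into a single .so; the returned
# module carries one generated Python API per registered operator.
custom_module = load(
    name='custom_relu2',
    sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'])

x = paddle.to_tensor(np.array([[-1., 1., 0.], [1., -1., -1.]], dtype='float32'))
out2 = custom_module.relu2(x)  # expected [[0., 1., 0.], [1., 0., 0.]]
out3 = custom_module.relu3(x)  # same result, produced by the second operator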
@@ -27,9 +27,9 @@ foreach(src ${TEST_OPS})
endforeach()
# Compiling .so will cost some time, but running process is very fast.
set_tests_properties(test_custom_op_with_setup PROPERTIES TIMEOUT 180)
set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180)
set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class Relu3Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
class Relu3OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
AddComment(R"DOC(
Relu3 Operator.
)DOC");
}
};
class Relu3GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
template <typename T>
class Relu3GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("relu3_grad");
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class Relu3Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < in_t->numel(); ++i) {
y[i] = std::max(static_cast<T>(0.), x[i]);
}
}
};
template <typename DeviceContext, typename T>
class Relu3GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < y_t->numel(); ++i) {
dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(relu3,
ops::Relu3Op,
ops::Relu3OpMaker,
ops::Relu3GradMaker<paddle::framework::OpDesc>,
ops::Relu3GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp);
REGISTER_OP_CPU_KERNEL(relu3,
ops::Relu3Kernel<CPU, float>,
ops::Relu3Kernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(relu3_grad,
ops::Relu3GradKernel<CPU, float>,
ops::Relu3GradKernel<CPU, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeRelu3(const T* x, const int num, T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<T>(0.));
}
}
template <typename DeviceContext, typename T>
class Relu3CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = in_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu3<T><<<grid, block, 0, dev_ctx.stream()>>>(x, num, y);
}
};
template <typename T>
__global__ void KeRelu3Grad(const T* y, const T* dy, const int num, T* dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
template <typename DeviceContext, typename T>
class Relu3GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = dy_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu3Grad<T><<<grid, block, 0, dev_ctx.stream()>>>(y, dy, num, dx);
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(relu3,
paddle::operators::Relu3CUDAKernel<CUDA, float>,
paddle::operators::Relu3CUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(relu3_grad,
paddle::operators::Relu3GradCUDAKernel<CUDA, float>,
paddle::operators::Relu3GradCUDAKernel<CUDA, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out);
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x);
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out);
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// Reuse the code in `relu_op_simple.cc/cu` to register another custom operator
// to test jointly compiling multiple operators at the same time.
PD_BUILD_OPERATOR("relu3")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
.SetBackwardOp("relu3_grad")
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
@@ -27,7 +27,8 @@ setup(
ext_modules=[
CUDAExtension(
name='librelu2_op_from_setup',
sources=['relu_op.cc', 'relu_op.cu'],
sources=['relu_op3.cc', 'relu_op3.cu', 'relu_op.cc',
'relu_op.cu'], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
],
......
@@ -25,7 +25,8 @@ setup(
ext_modules=[
CUDAExtension(
name='custom_relu2',
sources=['relu_op.cc', 'relu_op.cu'],
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc',
'relu_op3.cu'], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
@@ -22,7 +22,9 @@ setup(
ext_modules=[
CUDAExtension(
name='simple_setup_relu2',
sources=['relu_op_simple.cc', 'relu_op_simple.cu'],
sources=[
'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'
], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
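The setup.py route behaves the same way: after building and installing, the generated package exposes every operator that was compiled into its extension. A hedged sketch, assuming the simple_setup_relu2 package defined above has been installed via python setup.py install:

import numpy as np
import paddle
# the module name matches the CUDAExtension name declared in setup.py
import simple_setup_relu2

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
out2 = simple_setup_relu2.relu2(x)
out3 = simple_setup_relu2.relu3(x)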
@@ -24,9 +24,9 @@ from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
use_new_custom_op_load_method(False)
# Compile and load custom op Just-In-Time.
relu2 = load(
name='relu2',
sources=['relu_op.cc', 'relu_op.cu'],
custom_module = load(
name='custom_relu2',
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
interpreter='python', # add for unittest
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI
@@ -37,12 +37,14 @@ relu2 = load(
class TestJITLoad(unittest.TestCase):
def test_api(self):
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = relu2(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
out = custom_module.relu2(x)
out3 = custom_module.relu3(x)
self.assertTrue(np.array_equal(out.numpy(), gt_data))
self.assertTrue(np.array_equal(out3.numpy(), gt_data))
if __name__ == '__main__':
......
@@ -14,8 +14,11 @@
import os
import unittest
import numpy as np
from test_custom_op import CustomOpTest, load_so
import paddle
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.fluid.layer_helper import LayerHelper
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
@@ -32,6 +35,34 @@ def compile_so():
run_cmd(cmd)
# `setup.py build` only produces a .so file containing multiple operators.
# The Python interface should be added manually. The `relu2` api is in `test_custom_op.py`.
def relu3(x, name=None):
helper = LayerHelper("relu3", **locals())
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(type="relu3", inputs={"X": x}, outputs={"Y": out})
return out
class TestCompileMultiOp(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_relu3(self):
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = relu3(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
def tearDown(self):
paddle.enable_static()
if __name__ == '__main__':
compile_so()
load_so(so_name='librelu2_op_from_setup.so')
......
@@ -51,13 +51,14 @@ class TestSetUpInstall(unittest.TestCase):
import custom_relu2
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = custom_relu2.relu2(x)
out3 = custom_relu2.relu3(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
self.assertTrue(np.array_equal(out.numpy(), gt_data))
self.assertTrue(np.array_equal(out3.numpy(), gt_data))
if __name__ == '__main__':
......
@@ -21,16 +21,16 @@ from utils import paddle_includes, extra_compile_args
from test_simple_custom_op_setup import relu2_dynamic, relu2_static
# Compile and load custom op Just-In-Time.
simple_relu2 = load(
custom_module = load(
name='simple_jit_relu2',
sources=['relu_op_simple.cc', 'relu_op_simple.cu'],
sources=['relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args) # add for Coverage CI
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_op = simple_relu2
self.custom_ops = [custom_module.relu2, custom_module.relu3]
self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu']
@@ -38,24 +38,26 @@ class TestJITLoad(unittest.TestCase):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static(self.custom_op, device, dtype, x)
pd_out = relu2_static(self.custom_op, device, dtype, x, False)
for custom_op in self.custom_ops:
out = relu2_static(custom_op, device, dtype, x)
pd_out = relu2_static(custom_op, device, dtype, x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
for custom_op in self.custom_ops:
out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
self.assertTrue(
np.array_equal(x_grad, pd_x_grad),
"custom op x grad: {},\n paddle api x grad: {}".format(
......
@@ -107,7 +107,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
# usage: import the package directly
import simple_setup_relu2
self.custom_op = simple_setup_relu2.relu2
self.custom_ops = [simple_setup_relu2.relu2, simple_setup_relu2.relu3]
self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu']
@@ -116,36 +116,38 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static(self.custom_op, device, dtype, x)
pd_out = relu2_static(self.custom_op, device, dtype, x, False)
for custom_op in self.custom_ops:
out = relu2_static(custom_op, device, dtype, x)
pd_out = relu2_static(custom_op, device, dtype, x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
def test_static_pe(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = relu2_static_pe(self.custom_op, device, dtype, x)
pd_out = relu2_static_pe(self.custom_op, device, dtype, x,
False)
for custom_op in self.custom_ops:
out = relu2_static_pe(custom_op, device, dtype, x)
pd_out = relu2_static_pe(custom_op, device, dtype, x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
for custom_op in self.custom_ops:
out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
x, False)
self.assertTrue(
np.array_equal(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
"custom op out: {},\n paddle api out: {}".format(
out, pd_out))
self.assertTrue(
np.array_equal(x_grad, pd_x_grad),
"custom op x grad: {},\n paddle api x grad: {}".format(
......
@@ -271,23 +271,21 @@ class BuildExtension(build_ext, object):
"""
Record custom op information.
"""
# parse op name
sources = []
for extension in self.extensions:
sources.extend(extension.sources)
sources = [os.path.abspath(s) for s in sources]
op_name = parse_op_name_from(sources)
# parse shared library abs path
outputs = self.get_outputs()
assert len(outputs) == 1
# multiple operators are built into the same .so file
so_path = os.path.abspath(outputs[0])
so_name = os.path.basename(so_path)
for i, extension in enumerate(self.extensions):
sources = [os.path.abspath(s) for s in extension.sources]
op_names = parse_op_name_from(sources)
build_directory = os.path.abspath(outputs[0])
so_name = os.path.basename(build_directory)
for op_name in op_names:
CustomOpInfo.instance().add(op_name,
so_name=so_name,
build_directory=build_directory)
build_directory=so_path)
class EasyInstallCommand(easy_install, object):
......
@@ -70,7 +70,7 @@ ABI_INCOMPATIBILITY_WARNING = '''
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-insalled Paddle!
Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-installed Paddle!
Please use compiler that is ABI-compatible with GCC >= 5.4 (Recommended 8.2).
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html for ABI Compatibility
@@ -125,11 +125,12 @@ def custom_write_stub(resource, pyfile):
import types
import paddle
def inject_ext_module(module_name, api_name):
def inject_ext_module(module_name, api_names):
if module_name in sys.modules:
return sys.modules[module_name]
new_module = types.ModuleType(module_name)
for api_name in api_names:
setattr(new_module, api_name, eval(api_name))
return new_module
@@ -141,9 +142,8 @@ def custom_write_stub(resource, pyfile):
assert os.path.exists(so_path)
# load custom op shared library with abs path
new_custom_op = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
assert len(new_custom_op) == 1
m = inject_ext_module(__name__, new_custom_op[0])
new_custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
m = inject_ext_module(__name__, new_custom_ops)
__bootstrap__()
@@ -154,21 +154,25 @@ def custom_write_stub(resource, pyfile):
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.build_directory
new_custom_op = load_op_meta_info_and_register_op(so_path)
assert len(new_custom_op
) == 1, "The number of loaded costom operators is %d" % len(
new_custom_op)
new_custom_ops = load_op_meta_info_and_register_op(so_path)
assert len(
new_custom_ops
) > 0, "Required at least one custom operators, but received len(custom_op) = %d" % len(
new_custom_ops)
# NOTE: To avoid importing .so file instead of python file because they have same name,
# we rename .so shared library to another name, see EasyInstallCommand.
filename, ext = os.path.splitext(resource)
resource = filename + "_pd_" + ext
api_content = []
for op_name in new_custom_ops:
api_content.append(_custom_api_content(op_name))
with open(pyfile, 'w') as f:
f.write(
_stub_template.format(
resource=resource,
custom_api=_custom_api_content(new_custom_op[0])))
resource=resource, custom_api='\n\n'.join(api_content)))
OpInfo = collections.namedtuple('OpInfo',
@@ -406,11 +410,12 @@ def parse_op_info(op_name):
return in_names, out_infos
def _import_module_from_library(name, build_directory, verbose=False):
def _import_module_from_library(module_name, build_directory, verbose=False):
"""
Load .so shared library and import it as callable python module.
"""
ext_path = os.path.join(build_directory, name + '.so')
# TODO(Aurelius84): Consider file suffix is .dll on Windows Platform.
ext_path = os.path.join(build_directory, module_name + '.so')
if not os.path.exists(ext_path):
raise FileNotFoundError("Extension path: {} does not exist.".format(
ext_path))
@@ -418,27 +423,30 @@ def _import_module_from_library(name, build_directory, verbose=False):
# load custom op_info and kernels from .so shared library
log_v('loading shared library from: {}'.format(ext_path), verbose)
op_names = load_op_meta_info_and_register_op(ext_path)
assert len(op_names) == 1
# generate Python api in ext_path
return _generate_python_module(op_names[0], build_directory, verbose)
return _generate_python_module(module_name, op_names, build_directory,
verbose)
def _generate_python_module(op_name, build_directory, verbose=False):
def _generate_python_module(module_name,
op_names,
build_directory,
verbose=False):
"""
Automatically generate python file to allow import or load into as module
"""
api_file = os.path.join(build_directory, op_name + '.py')
api_file = os.path.join(build_directory, module_name + '.py')
log_v("generate api file: {}".format(api_file), verbose)
# write into .py file
api_content = _custom_api_content(op_name)
api_content = [_custom_api_content(op_name) for op_name in op_names]
with open(api_file, 'w') as f:
f.write(api_content)
f.write('\n\n'.join(api_content))
# load module
custom_api = _load_module_from_file(op_name, api_file, verbose)
return custom_api
custom_module = _load_module_from_file(api_file, verbose)
return custom_module
def _custom_api_content(op_name):
@@ -475,7 +483,7 @@ def _custom_api_content(op_name):
return api_content
def _load_module_from_file(op_name, api_file_path, verbose=False):
def _load_module_from_file(api_file_path, verbose=False):
"""
Load module from python file.
"""
@@ -494,8 +502,7 @@ def _load_module_from_file(op_name, api_file_path, verbose=False):
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
assert hasattr(module, op_name)
return getattr(module, op_name)
return module
def _get_api_inputs_str(op_name):
@@ -621,11 +628,7 @@ def parse_op_name_from(sources):
content = f.read()
op_names |= regex(content)
# TODO(Aurelius84): Support register more customs op at once
assert len(
op_names) == 1, "The number of registered costom operators is %d" % len(
op_names)
return list(op_names)[0]
return list(op_names)
def run_cmd(command, verbose=False):
......