未验证 提交 75f81233 编写于 作者: C Chen Weihang 提交者: GitHub

fix regex error & simplify marco name (#31031)

上级 f0ee1592
......@@ -290,12 +290,12 @@ class OpMetaInfoBuilder {
/////////////////////// Op register API /////////////////////////
// For inference: compile directly with framework
// Call after PD_BUILD_OPERATOR(...)
// Call after PD_BUILD_OP(...)
void RegisterAllCustomOperator();
/////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OPERATOR(op_name) \
#define PD_BUILD_OP(op_name) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##__COUNTER__##__ = \
::paddle::OpMetaInfoBuilder(op_name)
......
......@@ -125,8 +125,6 @@ T *Tensor::mutable_data() {
#ifdef PADDLE_WITH_CUDA
case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId();
VLOG(1) << "Custom Operator: mutable data cuda device id - "
<< device_num;
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
}
#endif
......
......@@ -31,7 +31,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator
// to test jointly compile multi operators at same time.
PD_BUILD_OPERATOR("relu3")
PD_BUILD_OP("relu3")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
......
......@@ -104,7 +104,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OPERATOR("relu2")
PD_BUILD_OP("relu2")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
......
......@@ -612,7 +612,7 @@ def parse_op_name_from(sources):
def regex(content):
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
pattern = re.compile(r'BUILD_OPERATOR\(([^,]+),')
pattern = re.compile(r'PD_BUILD_OP\(([^,\)]+)\)')
else:
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册