未验证 提交 75f81233 编写于 作者: C Chen Weihang 提交者: GitHub

fix regex error & simplify marco name (#31031)

上级 f0ee1592
...@@ -290,12 +290,12 @@ class OpMetaInfoBuilder { ...@@ -290,12 +290,12 @@ class OpMetaInfoBuilder {
/////////////////////// Op register API ///////////////////////// /////////////////////// Op register API /////////////////////////
// For inference: compile directly with framework // For inference: compile directly with framework
// Call after PD_BUILD_OPERATOR(...) // Call after PD_BUILD_OP(...)
void RegisterAllCustomOperator(); void RegisterAllCustomOperator();
/////////////////////// Op register Macro ///////////////////////// /////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OPERATOR(op_name) \ #define PD_BUILD_OP(op_name) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##__COUNTER__##__ = \ static ::paddle::OpMetaInfoBuilder __op_meta_info_##__COUNTER__##__ = \
::paddle::OpMetaInfoBuilder(op_name) ::paddle::OpMetaInfoBuilder(op_name)
......
...@@ -125,8 +125,6 @@ T *Tensor::mutable_data() { ...@@ -125,8 +125,6 @@ T *Tensor::mutable_data() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
case static_cast<int>(PlaceType::kGPU): { case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId(); int device_num = platform::GetCurrentDeviceId();
VLOG(1) << "Custom Operator: mutable data cuda device id - "
<< device_num;
return tensor->mutable_data<T>(platform::CUDAPlace(device_num)); return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
} }
#endif #endif
......
...@@ -31,7 +31,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype); ...@@ -31,7 +31,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator // Reuse codes in `relu_op_simple.cc/cu` to register another custom operator
// to test jointly compile multi operators at same time. // to test jointly compile multi operators at same time.
PD_BUILD_OPERATOR("relu3") PD_BUILD_OP("relu3")
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward)) .SetKernelFn(PD_KERNEL(ReluForward))
......
...@@ -104,7 +104,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) { ...@@ -104,7 +104,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype}; return {x_dtype};
} }
PD_BUILD_OPERATOR("relu2") PD_BUILD_OP("relu2")
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward)) .SetKernelFn(PD_KERNEL(ReluForward))
......
...@@ -612,7 +612,7 @@ def parse_op_name_from(sources): ...@@ -612,7 +612,7 @@ def parse_op_name_from(sources):
def regex(content): def regex(content):
if USING_NEW_CUSTOM_OP_LOAD_METHOD: if USING_NEW_CUSTOM_OP_LOAD_METHOD:
pattern = re.compile(r'BUILD_OPERATOR\(([^,]+),') pattern = re.compile(r'PD_BUILD_OP\(([^,\)]+)\)')
else: else:
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),') pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册