未验证 提交 126633c5 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Split build op marco & polish details (#31229)

* split build op marco & polish details

* revert register api del

* fix other unittest
上级 e8d24b54
......@@ -38,6 +38,8 @@ class PD_DLL_DECL OpMetaInfoHelper;
using Tensor = paddle::Tensor;
///////////////// Util Marco Define ////////////////
#define PD_DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
......@@ -65,6 +67,12 @@ using Tensor = paddle::Tensor;
END_HANDLE_THE_ERROR \
} while (0)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
///////////////// Util Define and Function ////////////////
inline std::string Grad(const std::string& var_name) {
......@@ -288,9 +296,9 @@ class PD_DLL_DECL OpMetaInfo {
std::vector<std::string> attrs_;
// 2. func info
KernelFunc kernel_fn_;
InferShapeFunc infer_shape_fn_;
InferDtypeFunc infer_dtype_fn_;
KernelFunc kernel_fn_{nullptr};
InferShapeFunc infer_shape_fn_{nullptr};
InferDtypeFunc infer_dtype_fn_{nullptr};
};
//////////////// Op Meta Info Map /////////////////
......@@ -321,20 +329,22 @@ class PD_DLL_DECL OpMetaInfoMap {
class PD_DLL_DECL OpMetaInfoBuilder {
public:
explicit OpMetaInfoBuilder(std::string&& name);
explicit OpMetaInfoBuilder(std::string&& name, size_t index);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
private:
// Forward Op name
std::string name_;
// Point to the currently constructed op meta info
// ref current info ptr
OpMetaInfo* info_ptr_;
// The current op meta info index in vector
// - 0: op, 1: grad_op, 2: grad_grad_op
size_t index_;
};
/////////////////////// Op register API /////////////////////////
......@@ -350,14 +360,25 @@ void LoadCustomOperatorLib(const std::string& dso_name);
/////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OP_WITH_COUNTER(op_name, counter) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##counter##__ = \
::paddle::OpMetaInfoBuilder(op_name)
#define PD_BUILD_OP_INNER(op_name, counter) \
PD_BUILD_OP_WITH_COUNTER(op_name, counter)
#define PD_BUILD_OP(op_name) PD_BUILD_OP_INNER(op_name, __COUNTER__)
#define PD_BUILD_OP(op_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_name, "PD_BUILD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##op_name##__ = \
::paddle::OpMetaInfoBuilder(#op_name, 0)
#define PD_BUILD_GRAD_OP(op_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_grad_op__##op_name, \
"PD_BUILD_GRAD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __grad_op_meta_info_##op_name##__ = \
::paddle::OpMetaInfoBuilder(#op_name, 1)
#define PD_BUILD_DOUBLE_GRAD_OP(op_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_grad_grad_op__##op_name, \
"PD_BUILD_DOUBLE_GRAD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __grad_grad_op_meta_info_##op_name##__ = \
::paddle::OpMetaInfoBuilder(#op_name, 2)
} // namespace paddle
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -62,11 +63,38 @@ OpMetaInfoMap::GetMap() const {
//////////////// Op Meta Info Builder /////////////////
OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) {
OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name, size_t index) {
// 1. member assign
name_ = std::forward<std::string>(name);
index_ = index;
// 2. check and meta info build
auto& info_vector = OpMetaInfoMap::Instance()[name_];
// index check
PADDLE_ENFORCE_EQ(
info_vector.size(), index_,
platform::errors::PreconditionNotMet(
"The operator %s's meta info register failed. "
"Please make sure you call marcos as order `PD_BUILD_OP`, "
"`PD_BUILD_GRAD_OP`, `PD_BUILD_DOUBLE_GRAD_OP`.",
name_));
switch (index_) {
case 0:
break;
case 1:
name_ = name_ + "_grad";
break;
case 2:
name_ = name_ + "_grad_grad";
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Not support index `%d` when construct OpMetaInfoBuilder, "
"now only support `0, 1, 2`.",
index_));
}
auto op_meta = OpMetaInfo(name_);
info_vector.emplace_back(std::move(op_meta));
// 3. get current info ptr
info_ptr_ = &(info_vector.back());
}
......@@ -93,24 +121,27 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
PADDLE_ENFORCE_EQ(
index_, 0UL,
platform::errors::Unimplemented(
"Currently, the InferShapeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the shape of forward Tensor "
"`X` by default."));
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
PADDLE_ENFORCE_EQ(
index_, 0UL,
platform::errors::Unimplemented(
"Currently, the InferDtypeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the dtype of forward Tensor "
"`X` by default."));
info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp(
const std::string& bwd_op_name) {
auto& info_vector = OpMetaInfoMap::Instance()[name_];
auto op_meta = OpMetaInfo(bwd_op_name);
info_vector.emplace_back(std::move(op_meta));
info_ptr_ = &(info_vector.back());
return *this;
}
/////////////////////// Op register API /////////////////////////
void RegisterAllCustomOperator() {
......
......@@ -153,6 +153,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
}
VLOG(1) << "Run ComputeFunc.";
try {
auto outs = func(custom_ins, custom_attrs);
VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
......@@ -160,6 +161,14 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
auto* true_out = ctx.Output<Tensor>(outputs[i]);
CustomTensorUtils::ShareDataTo(outs.at(i), true_out);
}
} catch (platform::EnforceNotMet& exception) {
throw std::move(exception);
} catch (std::exception& ex) {
PADDLE_THROW(platform::errors::External("%s", ex.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Custom operator raises an unknown exception in rumtime."));
}
}
//////////////////// Operator Define /////////////////
......@@ -475,11 +484,34 @@ void RegisterOperatorWithMetaInfo(
op_name, info.proto_->InitializationErrorString()));
// InferShape
PADDLE_ENFORCE_NOT_NULL(
infer_shape_func,
platform::errors::PreconditionNotMet(
"InferShapeFn is nullptr. Need to set the InferShapeFn of custom "
if (infer_shape_func == nullptr) {
// use default InferShape
info.infer_shape_ = [op_inputs, op_outputs](InferShapeContext* ctx) {
PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferShapeFn. At this time, "
"the input shape will be directly set to the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferShapeFn. At this time, "
"the input shape will be directly set to the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
ctx->ShareDim(op_inputs[0], op_outputs[0]);
};
} else {
info.infer_shape_ = [op_inputs, op_outputs,
infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes;
......@@ -496,16 +528,42 @@ void RegisterOperatorWithMetaInfo(
VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) {
ctx->SetOutputDim(op_outputs[i], framework::make_ddim(output_shapes[i]));
ctx->SetOutputDim(op_outputs[i],
framework::make_ddim(output_shapes[i]));
}
};
}
// Infer Dtype
PADDLE_ENFORCE_NOT_NULL(
infer_dtype_func,
platform::errors::PreconditionNotMet(
"InferDtypeFn is nullptr. Need to set the InferDtypeFn of custom "
if (infer_dtype_func == nullptr) {
// use defalut InferDtype
info.infer_var_type_ = [op_inputs, op_outputs](InferVarTypeContext* ctx) {
PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferDtypeFn. At this time, "
"the input dtype will be directly set to the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferDtypeFn. At this time, "
"the input dtype will be directly set to the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
VLOG(1) << "Custom Operator: InferDtype - share dtype.";
auto dtype = ctx->GetInputDataType(op_inputs[0]);
ctx->SetOutputDataType(op_outputs[0], dtype);
};
} else {
info.infer_var_type_ = [op_inputs, op_outputs,
infer_dtype_func](InferVarTypeContext* ctx) {
std::vector<DataType> input_dtypes;
......@@ -527,6 +585,7 @@ void RegisterOperatorWithMetaInfo(
CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i]));
}
};
}
// Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
......
......@@ -150,15 +150,7 @@ std::vector<paddle::Tensor> AttrTestBackward(
return {grad_x};
}
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP("attr_test")
PD_BUILD_OP(attr_test)
.Inputs({"X"})
.Outputs({"Out"})
.Attrs({"bool_attr: bool",
......@@ -170,10 +162,9 @@ PD_BUILD_OP("attr_test")
"float_vec_attr: std::vector<float>",
"int64_vec_attr: std::vector<int64_t>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestForward))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType))
.SetBackwardOp("attr_test_grad")
.SetKernelFn(PD_KERNEL(AttrTestForward));
PD_BUILD_GRAD_OP(attr_test)
.Inputs({paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.Attrs({"int_attr: int",
......
......@@ -96,21 +96,12 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
}
}
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP("custom_relu")
PD_BUILD_OP(custom_relu)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
.SetBackwardOp("relu2_grad")
.SetKernelFn(PD_KERNEL(ReluForward));
PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
......@@ -25,19 +25,14 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out);
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// Reuse codes in `custom_relu_op.cc/cu` to register another custom operator
// to test jointly compile multi operators at same time.
PD_BUILD_OP("custom_relu_dup")
PD_BUILD_OP(custom_relu_dup)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
.SetBackwardOp("relu3_grad")
.SetKernelFn(PD_KERNEL(ReluForward));
PD_BUILD_GRAD_OP(custom_relu_dup)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
......@@ -26,14 +26,6 @@ void assign_cpu_kernel(const data_t* x_data,
}
}
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
......@@ -47,12 +39,10 @@ std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) {
return {out};
}
PD_BUILD_OP("dispatch_test_integer")
PD_BUILD_OP(dispatch_test_integer)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestInterger))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
.SetKernelFn(PD_KERNEL(DispatchTestInterger));
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
......@@ -67,12 +57,10 @@ std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
return {out};
}
PD_BUILD_OP("dispatch_test_complex")
PD_BUILD_OP(dispatch_test_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
.SetKernelFn(PD_KERNEL(DispatchTestComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
const paddle::Tensor& x) {
......@@ -88,12 +76,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_integer")
PD_BUILD_OP(dispatch_test_float_and_integer)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger));
std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
const paddle::Tensor& x) {
......@@ -109,12 +95,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_complex")
PD_BUILD_OP(dispatch_test_float_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex));
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
const paddle::Tensor& x) {
......@@ -130,9 +114,7 @@ std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_integer_and_complex")
PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
......@@ -68,7 +68,7 @@ std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
}
PD_BUILD_OP("multi_out")
PD_BUILD_OP(multi_out)
.Inputs({"X"})
.Outputs({"Out", "Fake_float64", "ZFake_int32"})
.SetKernelFn(PD_KERNEL(MultiOutCPU))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册