未验证 提交 126633c5 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Split build op marco & polish details (#31229)

* split build op marco & polish details

* revert register api del

* fix other unittest
上级 e8d24b54
...@@ -38,6 +38,8 @@ class PD_DLL_DECL OpMetaInfoHelper; ...@@ -38,6 +38,8 @@ class PD_DLL_DECL OpMetaInfoHelper;
using Tensor = paddle::Tensor; using Tensor = paddle::Tensor;
///////////////// Util Marco Define ////////////////
#define PD_DISABLE_COPY_AND_ASSIGN(classname) \ #define PD_DISABLE_COPY_AND_ASSIGN(classname) \
private: \ private: \
classname(const classname&) = delete; \ classname(const classname&) = delete; \
...@@ -65,6 +67,12 @@ using Tensor = paddle::Tensor; ...@@ -65,6 +67,12 @@ using Tensor = paddle::Tensor;
END_HANDLE_THE_ERROR \ END_HANDLE_THE_ERROR \
} while (0) } while (0)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
///////////////// Util Define and Function //////////////// ///////////////// Util Define and Function ////////////////
inline std::string Grad(const std::string& var_name) { inline std::string Grad(const std::string& var_name) {
...@@ -288,9 +296,9 @@ class PD_DLL_DECL OpMetaInfo { ...@@ -288,9 +296,9 @@ class PD_DLL_DECL OpMetaInfo {
std::vector<std::string> attrs_; std::vector<std::string> attrs_;
// 2. func info // 2. func info
KernelFunc kernel_fn_; KernelFunc kernel_fn_{nullptr};
InferShapeFunc infer_shape_fn_; InferShapeFunc infer_shape_fn_{nullptr};
InferDtypeFunc infer_dtype_fn_; InferDtypeFunc infer_dtype_fn_{nullptr};
}; };
//////////////// Op Meta Info Map ///////////////// //////////////// Op Meta Info Map /////////////////
...@@ -321,20 +329,22 @@ class PD_DLL_DECL OpMetaInfoMap { ...@@ -321,20 +329,22 @@ class PD_DLL_DECL OpMetaInfoMap {
class PD_DLL_DECL OpMetaInfoBuilder { class PD_DLL_DECL OpMetaInfoBuilder {
public: public:
explicit OpMetaInfoBuilder(std::string&& name); explicit OpMetaInfoBuilder(std::string&& name, size_t index);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs); OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs); OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs); OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
private: private:
// Forward Op name // Forward Op name
std::string name_; std::string name_;
// Point to the currently constructed op meta info // ref current info ptr
OpMetaInfo* info_ptr_; OpMetaInfo* info_ptr_;
// The current op meta info index in vector
// - 0: op, 1: grad_op, 2: grad_grad_op
size_t index_;
}; };
/////////////////////// Op register API ///////////////////////// /////////////////////// Op register API /////////////////////////
...@@ -350,14 +360,25 @@ void LoadCustomOperatorLib(const std::string& dso_name); ...@@ -350,14 +360,25 @@ void LoadCustomOperatorLib(const std::string& dso_name);
/////////////////////// Op register Macro ///////////////////////// /////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OP_WITH_COUNTER(op_name, counter) \ #define PD_BUILD_OP(op_name) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##counter##__ = \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
::paddle::OpMetaInfoBuilder(op_name) __reg_op__##op_name, "PD_BUILD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##op_name##__ = \
#define PD_BUILD_OP_INNER(op_name, counter) \ ::paddle::OpMetaInfoBuilder(#op_name, 0)
PD_BUILD_OP_WITH_COUNTER(op_name, counter)
#define PD_BUILD_GRAD_OP(op_name) \
#define PD_BUILD_OP(op_name) PD_BUILD_OP_INNER(op_name, __COUNTER__) STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_grad_op__##op_name, \
"PD_BUILD_GRAD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __grad_op_meta_info_##op_name##__ = \
::paddle::OpMetaInfoBuilder(#op_name, 1)
#define PD_BUILD_DOUBLE_GRAD_OP(op_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_grad_grad_op__##op_name, \
"PD_BUILD_DOUBLE_GRAD_OP must be called in global namespace."); \
static ::paddle::OpMetaInfoBuilder __grad_grad_op_meta_info_##op_name##__ = \
::paddle::OpMetaInfoBuilder(#op_name, 2)
} // namespace paddle } // namespace paddle
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -62,11 +63,38 @@ OpMetaInfoMap::GetMap() const { ...@@ -62,11 +63,38 @@ OpMetaInfoMap::GetMap() const {
//////////////// Op Meta Info Builder ///////////////// //////////////// Op Meta Info Builder /////////////////
OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) { OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name, size_t index) {
// 1. member assign
name_ = std::forward<std::string>(name); name_ = std::forward<std::string>(name);
index_ = index;
// 2. check and meta info build
auto& info_vector = OpMetaInfoMap::Instance()[name_]; auto& info_vector = OpMetaInfoMap::Instance()[name_];
// index check
PADDLE_ENFORCE_EQ(
info_vector.size(), index_,
platform::errors::PreconditionNotMet(
"The operator %s's meta info register failed. "
"Please make sure you call marcos as order `PD_BUILD_OP`, "
"`PD_BUILD_GRAD_OP`, `PD_BUILD_DOUBLE_GRAD_OP`.",
name_));
switch (index_) {
case 0:
break;
case 1:
name_ = name_ + "_grad";
break;
case 2:
name_ = name_ + "_grad_grad";
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Not support index `%d` when construct OpMetaInfoBuilder, "
"now only support `0, 1, 2`.",
index_));
}
auto op_meta = OpMetaInfo(name_); auto op_meta = OpMetaInfo(name_);
info_vector.emplace_back(std::move(op_meta)); info_vector.emplace_back(std::move(op_meta));
// 3. get current info ptr
info_ptr_ = &(info_vector.back()); info_ptr_ = &(info_vector.back());
} }
...@@ -93,24 +121,27 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) { ...@@ -93,24 +121,27 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
} }
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) { OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
PADDLE_ENFORCE_EQ(
index_, 0UL,
platform::errors::Unimplemented(
"Currently, the InferShapeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the shape of forward Tensor "
"`X` by default."));
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func)); info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
return *this; return *this;
} }
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) { OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
PADDLE_ENFORCE_EQ(
index_, 0UL,
platform::errors::Unimplemented(
"Currently, the InferDtypeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the dtype of forward Tensor "
"`X` by default."));
info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func)); info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
return *this; return *this;
} }
OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp(
const std::string& bwd_op_name) {
auto& info_vector = OpMetaInfoMap::Instance()[name_];
auto op_meta = OpMetaInfo(bwd_op_name);
info_vector.emplace_back(std::move(op_meta));
info_ptr_ = &(info_vector.back());
return *this;
}
/////////////////////// Op register API ///////////////////////// /////////////////////// Op register API /////////////////////////
void RegisterAllCustomOperator() { void RegisterAllCustomOperator() {
......
...@@ -153,6 +153,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -153,6 +153,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
} }
VLOG(1) << "Run ComputeFunc."; VLOG(1) << "Run ComputeFunc.";
try {
auto outs = func(custom_ins, custom_attrs); auto outs = func(custom_ins, custom_attrs);
VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
...@@ -160,6 +161,14 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -160,6 +161,14 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
auto* true_out = ctx.Output<Tensor>(outputs[i]); auto* true_out = ctx.Output<Tensor>(outputs[i]);
CustomTensorUtils::ShareDataTo(outs.at(i), true_out); CustomTensorUtils::ShareDataTo(outs.at(i), true_out);
} }
} catch (platform::EnforceNotMet& exception) {
throw std::move(exception);
} catch (std::exception& ex) {
PADDLE_THROW(platform::errors::External("%s", ex.what()));
} catch (...) {
PADDLE_THROW(platform::errors::Fatal(
"Custom operator raises an unknown exception in rumtime."));
}
} }
//////////////////// Operator Define ///////////////// //////////////////// Operator Define /////////////////
...@@ -475,11 +484,34 @@ void RegisterOperatorWithMetaInfo( ...@@ -475,11 +484,34 @@ void RegisterOperatorWithMetaInfo(
op_name, info.proto_->InitializationErrorString())); op_name, info.proto_->InitializationErrorString()));
// InferShape // InferShape
PADDLE_ENFORCE_NOT_NULL( if (infer_shape_func == nullptr) {
infer_shape_func, // use default InferShape
platform::errors::PreconditionNotMet( info.infer_shape_ = [op_inputs, op_outputs](InferShapeContext* ctx) {
"InferShapeFn is nullptr. Need to set the InferShapeFn of custom " PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferShapeFn. At this time, "
"the input shape will be directly set to the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferShapeFn. At this time, "
"the input shape will be directly set to the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
ctx->ShareDim(op_inputs[0], op_outputs[0]);
};
} else {
info.infer_shape_ = [op_inputs, op_outputs, info.infer_shape_ = [op_inputs, op_outputs,
infer_shape_func](InferShapeContext* ctx) { infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes; std::vector<std::vector<int64_t>> input_shapes;
...@@ -496,16 +528,42 @@ void RegisterOperatorWithMetaInfo( ...@@ -496,16 +528,42 @@ void RegisterOperatorWithMetaInfo(
VLOG(1) << "Custom Operator: InferShape - set output ddim."; VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) { for (size_t i = 0; i < op_outputs.size(); ++i) {
ctx->SetOutputDim(op_outputs[i], framework::make_ddim(output_shapes[i])); ctx->SetOutputDim(op_outputs[i],
framework::make_ddim(output_shapes[i]));
} }
}; };
}
// Infer Dtype // Infer Dtype
PADDLE_ENFORCE_NOT_NULL( if (infer_dtype_func == nullptr) {
infer_dtype_func, // use defalut InferDtype
platform::errors::PreconditionNotMet( info.infer_var_type_ = [op_inputs, op_outputs](InferVarTypeContext* ctx) {
"InferDtypeFn is nullptr. Need to set the InferDtypeFn of custom " PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferDtypeFn. At this time, "
"the input dtype will be directly set to the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and "
"only one output without setting the InferDtypeFn. At this time, "
"the input dtype will be directly set to the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))")); "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
VLOG(1) << "Custom Operator: InferDtype - share dtype.";
auto dtype = ctx->GetInputDataType(op_inputs[0]);
ctx->SetOutputDataType(op_outputs[0], dtype);
};
} else {
info.infer_var_type_ = [op_inputs, op_outputs, info.infer_var_type_ = [op_inputs, op_outputs,
infer_dtype_func](InferVarTypeContext* ctx) { infer_dtype_func](InferVarTypeContext* ctx) {
std::vector<DataType> input_dtypes; std::vector<DataType> input_dtypes;
...@@ -527,6 +585,7 @@ void RegisterOperatorWithMetaInfo( ...@@ -527,6 +585,7 @@ void RegisterOperatorWithMetaInfo(
CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i])); CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i]));
} }
}; };
}
// Kernel func // Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs); RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
......
...@@ -150,15 +150,7 @@ std::vector<paddle::Tensor> AttrTestBackward( ...@@ -150,15 +150,7 @@ std::vector<paddle::Tensor> AttrTestBackward(
return {grad_x}; return {grad_x};
} }
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) { PD_BUILD_OP(attr_test)
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP("attr_test")
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.Attrs({"bool_attr: bool", .Attrs({"bool_attr: bool",
...@@ -170,10 +162,9 @@ PD_BUILD_OP("attr_test") ...@@ -170,10 +162,9 @@ PD_BUILD_OP("attr_test")
"float_vec_attr: std::vector<float>", "float_vec_attr: std::vector<float>",
"int64_vec_attr: std::vector<int64_t>", "int64_vec_attr: std::vector<int64_t>",
"str_vec_attr: std::vector<std::string>"}) "str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestForward)) .SetKernelFn(PD_KERNEL(AttrTestForward));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType)) PD_BUILD_GRAD_OP(attr_test)
.SetBackwardOp("attr_test_grad")
.Inputs({paddle::Grad("Out")}) .Inputs({paddle::Grad("Out")})
.Outputs({paddle::Grad("X")}) .Outputs({paddle::Grad("X")})
.Attrs({"int_attr: int", .Attrs({"int_attr: int",
......
...@@ -96,21 +96,12 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x, ...@@ -96,21 +96,12 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
} }
} }
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) { PD_BUILD_OP(custom_relu)
return {x_shape};
}
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP("custom_relu")
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward)) .SetKernelFn(PD_KERNEL(ReluForward));
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) PD_BUILD_GRAD_OP(custom_relu)
.SetBackwardOp("relu2_grad")
.Inputs({"X", "Out", paddle::Grad("Out")}) .Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")}) .Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward)); .SetKernelFn(PD_KERNEL(ReluBackward));
...@@ -25,19 +25,14 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x, ...@@ -25,19 +25,14 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out, const paddle::Tensor& out,
const paddle::Tensor& grad_out); const paddle::Tensor& grad_out);
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// Reuse codes in `custom_relu_op.cc/cu` to register another custom operator // Reuse codes in `custom_relu_op.cc/cu` to register another custom operator
// to test jointly compile multi operators at same time. // to test jointly compile multi operators at same time.
PD_BUILD_OP("custom_relu_dup") PD_BUILD_OP(custom_relu_dup)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward)) .SetKernelFn(PD_KERNEL(ReluForward));
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) PD_BUILD_GRAD_OP(custom_relu_dup)
.SetBackwardOp("relu3_grad")
.Inputs({"X", "Out", paddle::Grad("Out")}) .Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")}) .Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward)); .SetKernelFn(PD_KERNEL(ReluBackward));
...@@ -26,14 +26,6 @@ void assign_cpu_kernel(const data_t* x_data, ...@@ -26,14 +26,6 @@ void assign_cpu_kernel(const data_t* x_data,
} }
} }
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) { std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU); auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape()); out.reshape(x.shape());
...@@ -47,12 +39,10 @@ std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) { ...@@ -47,12 +39,10 @@ std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) {
return {out}; return {out};
} }
PD_BUILD_OP("dispatch_test_integer") PD_BUILD_OP(dispatch_test_integer)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestInterger)) .SetKernelFn(PD_KERNEL(DispatchTestInterger));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) { std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU); auto out = paddle::Tensor(paddle::PlaceType::kCPU);
...@@ -67,12 +57,10 @@ std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) { ...@@ -67,12 +57,10 @@ std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
return {out}; return {out};
} }
PD_BUILD_OP("dispatch_test_complex") PD_BUILD_OP(dispatch_test_complex)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex)) .SetKernelFn(PD_KERNEL(DispatchTestComplex));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndInteger( std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
const paddle::Tensor& x) { const paddle::Tensor& x) {
...@@ -88,12 +76,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndInteger( ...@@ -88,12 +76,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
return {out}; return {out};
} }
PD_BUILD_OP("dispatch_test_float_and_integer") PD_BUILD_OP(dispatch_test_float_and_integer)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger)) .SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndComplex( std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
const paddle::Tensor& x) { const paddle::Tensor& x) {
...@@ -109,12 +95,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndComplex( ...@@ -109,12 +95,10 @@ std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
return {out}; return {out};
} }
PD_BUILD_OP("dispatch_test_float_and_complex") PD_BUILD_OP(dispatch_test_float_and_complex)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex)) .SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex( std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
const paddle::Tensor& x) { const paddle::Tensor& x) {
...@@ -130,9 +114,7 @@ std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex( ...@@ -130,9 +114,7 @@ std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
return {out}; return {out};
} }
PD_BUILD_OP("dispatch_test_float_and_integer_and_complex") PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out"}) .Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex)) .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
...@@ -68,7 +68,7 @@ std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) { ...@@ -68,7 +68,7 @@ std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32}; return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
} }
PD_BUILD_OP("multi_out") PD_BUILD_OP(multi_out)
.Inputs({"X"}) .Inputs({"X"})
.Outputs({"Out", "Fake_float64", "ZFake_int32"}) .Outputs({"Out", "Fake_float64", "ZFake_int32"})
.SetKernelFn(PD_KERNEL(MultiOutCPU)) .SetKernelFn(PD_KERNEL(MultiOutCPU))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册