未验证 提交 0f888836 编写于 作者: Z Zeng Jinle 提交者: GitHub

Polish op registry codes (#21561)

* polish infer shape registry, test=develop

* modify some operators registry, test=develop
上级 3d9dee57
......@@ -155,17 +155,41 @@ class OperatorRegistrarRecursive<I, true, ARGS...> {
template <typename T>
struct OpInfoFiller<T, kOperator> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(info->creator_, nullptr,
platform::errors::AlreadyExists(
"OpCreator of %s has been registered", op_type));
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
return new T(type, inputs, outputs, attrs);
};
if (std::is_base_of<OperatorWithKernel, T>::value) {
PADDLE_ENFORCE_EQ(
info->infer_shape_, nullptr,
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
auto* op =
dynamic_cast<OperatorWithKernel*>(info->creator_("", {}, {}, {}));
PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument(
"%s should have kernels", op_type));
info->infer_shape_ = [op](InferShapeContext* ctx) {
op->InferShape(ctx);
};
}
}
};
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(info->proto_, nullptr,
platform::errors::AlreadyExists(
"OpProto of %s has been registered", op_type));
PADDLE_ENFORCE_EQ(info->checker_, nullptr,
platform::errors::AlreadyExists(
"OpAttrChecker of %s has been registered", op_type));
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
T maker;
......@@ -181,6 +205,11 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->grad_op_maker_, nullptr,
platform::errors::AlreadyExists(
"GradOpDescMaker of %s has been registered", op_type));
info->grad_op_maker_ = [](
const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
......@@ -199,6 +228,11 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template <typename T>
struct OpInfoFiller<T, kGradOpBaseMaker> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->dygraph_grad_op_maker_, nullptr,
platform::errors::AlreadyExists(
"GradOpBaseMaker of %s has been registered", op_type));
info->dygraph_grad_op_maker_ = [](
const imperative::OpBase* fw_op_base,
const imperative::NameVarBaseMap& var_base_map_in,
......@@ -212,6 +246,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_var_type_, nullptr,
platform::errors::AlreadyExists(
"VarTypeInference of %s has been registered", op_type));
info->infer_var_type_ = [](InferVarTypeContext* context) {
T inference;
inference(context);
......@@ -222,6 +260,10 @@ struct OpInfoFiller<T, kVarTypeInference> {
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_shape_, nullptr,
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
......@@ -232,6 +274,10 @@ struct OpInfoFiller<T, kShapeInference> {
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_inplace_, nullptr,
platform::errors::AlreadyExists(
"InplaceOpInference of %s has been registered", op_type));
info->infer_inplace_ = [](bool use_cuda) {
T infer;
return infer(use_cuda);
......@@ -242,6 +288,10 @@ struct OpInfoFiller<T, kInplaceOpInference> {
template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_no_need_buffer_vars_, nullptr,
platform::errors::AlreadyExists(
"NoNeedBufferVarsInference of %s has been registered", op_type));
info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
}
};
......
......@@ -124,9 +124,23 @@ class InferNoNeedBufferVarsFN {
inferer_ = inferer;
}
inline bool operator==(std::nullptr_t) const { return inferer_ == nullptr; }
inline bool operator!=(std::nullptr_t) const { return inferer_ != nullptr; }
private:
std::shared_ptr<NoNeedBufferVarsInference> inferer_;
};
static inline bool operator==(std::nullptr_t,
const InferNoNeedBufferVarsFN &other) {
return other == nullptr;
}
static inline bool operator!=(std::nullptr_t,
const InferNoNeedBufferVarsFN &other) {
return other != nullptr;
}
} // namespace framework
} // namespace paddle
......@@ -48,5 +48,31 @@ TEST(test_no_need_buffer_vars_inference, test_dygraph) {
ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
}
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(TestNoNeedBufferVarsInferer, "X1", "X2");
TEST(test_no_need_buffer_vars_inference, test_nullptr_comparation) {
InferNoNeedBufferVarsFN infer_fn;
ASSERT_FALSE(static_cast<bool>(infer_fn));
ASSERT_TRUE(!infer_fn);
ASSERT_TRUE(infer_fn == nullptr);
ASSERT_TRUE(nullptr == infer_fn);
ASSERT_FALSE(infer_fn != nullptr);
ASSERT_FALSE(nullptr != infer_fn);
infer_fn.Reset(std::make_shared<TestNoNeedBufferVarsInferer>());
ASSERT_TRUE(static_cast<bool>(infer_fn));
ASSERT_FALSE(!infer_fn);
ASSERT_FALSE(infer_fn == nullptr);
ASSERT_FALSE(nullptr == infer_fn);
ASSERT_TRUE(infer_fn != nullptr);
ASSERT_TRUE(nullptr != infer_fn);
auto no_need_slots =
infer_fn(VariableNameMap{}, VariableNameMap{}, AttributeMap{});
ASSERT_EQ(no_need_slots.size(), 2UL);
ASSERT_EQ(no_need_slots.count("X1"), 1UL);
ASSERT_EQ(no_need_slots.count("X2"), 1UL);
}
} // namespace framework
} // namespace paddle
......@@ -653,55 +653,6 @@ void OpDesc::Flush() {
}
}
static std::once_flag init_infer_shape_funcs;
/**
* NOTE(paddle-dev): Very tricky code here. Maybe we should find a
* better way to register compile-time infershape method gentlely.
*
* Normally, we can register a class derived from InferShapeBase, so that
* we can set the field of `infer_shape_` inside OpInfo when registering op.
*
* However, there is another way we can set the field of `infer_shape_` inside
* OpInfo. Usually, we overload InferShape method of OperatorWithKernel. After
* running the following method InitInferShapeFuncs, `infer_shape_` would be set
* to be the InferShape method of OperatorWithKernel. That is to say, we borrow
* the run-time InferShape method of OperatorWithKernel to be the compile-time
* InferShape method.
*
* However, during compiling time, we may not know inputs, outputs and attrs of
* run-time OperatorWithKernel. So the following code creates a fake
* OperatorWithKernel object. That is why the field info_ of OperatorBase
* would be null.
*/
static void InitInferShapeFuncs() {
std::call_once(init_infer_shape_funcs, [] {
auto &map = OpInfoMap::Instance();
auto &info_map = *map.mutable_map();
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
auto op_type = kern_pair.first;
auto it = info_map.find(op_type);
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
op_type);
auto &op_info = it->second;
if (op_info.infer_shape_) { // infer_shape has been registered.
continue;
}
auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
PADDLE_ENFORCE_NOT_NULL(
op, "InferShapeBase is not registered to Operator %s", op_type);
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
op->InferShape(ctx);
};
}
});
}
void OpDesc::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted.");
......@@ -718,7 +669,6 @@ void OpDesc::CheckAttrs() {
void OpDesc::InferShape(const BlockDesc &block) const {
try {
VLOG(3) << "CompileTime infer shape on " << Type();
InitInferShapeFuncs();
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
"%s's infer_shape has not been registered", this->Type());
......
......@@ -38,17 +38,6 @@ the input dtype, but it's fine if you do so.
}
};
class CastOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set");
PADDLE_ENFORCE(context->HasOutput("Out"),
"The output of cast op must be set");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
template <typename T>
class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -71,6 +60,17 @@ class CastOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(
context->HasInput("X"), true,
platform::errors::NotFound("The input(X) of cast op must be set"));
PADDLE_ENFORCE_EQ(
context->HasOutput("Out"), true,
platform::errors::NotFound("The output of cast op must be set"));
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
......@@ -88,7 +88,7 @@ using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpInferShape, ops::CastOpProtoMaker);
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
......
......@@ -73,9 +73,12 @@ calculated by $%s$
};
template <typename OpComment>
class CompareOpInferShape : public framework::InferShapeBase {
class CompareOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext* context) const override {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X",
comment.type);
......@@ -89,13 +92,7 @@ class CompareOpInferShape : public framework::InferShapeBase {
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
......@@ -118,9 +115,8 @@ class CompareOp : public framework::OperatorWithKernel {
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -57,10 +57,44 @@ Each element of Out is calculated by %s
}
};
class LogicalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
template <typename OpComment>
class UnaryLogicalOp : public LogicalOp {
public:
using LogicalOp::LogicalOp;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(
context->HasInput("X"), true,
platform::errors::NotFound("Input(X) of %s operator must not be null",
comment.type));
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
template <typename OpComment>
class BinaryLogicalOpInferShape : public framework::InferShapeBase {
class BinaryLogicalOp : public LogicalOp {
public:
void operator()(framework::InferShapeContext *context) const override {
using LogicalOp::LogicalOp;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
"Input(X) of %s operator must not be null", comment.type);
......@@ -84,32 +118,6 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
}
};
template <typename OpComment>
class UnaryLogicalOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
"Input(X) of %s operator must not be null", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
class LogicalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators
} // namespace paddle
......@@ -121,9 +129,8 @@ class LogicalOp : public framework::OperatorWithKernel {
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::LogicalOp, \
op_type, ::paddle::operators::BinaryLogicalOp<_##op_type##Comment>, \
::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -135,9 +142,8 @@ class LogicalOp : public framework::OperatorWithKernel {
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::LogicalOp, \
op_type, ::paddle::operators::UnaryLogicalOp<_##op_type##Comment>, \
::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -88,6 +88,13 @@ class ExpandAsGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
......@@ -108,7 +115,7 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandAsGradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -117,7 +124,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_as, ops::ExpandAsOp, ops::ExpandAsOpMaker,
ops::ExpandAsGradOpMaker<paddle::framework::OpDesc>,
ops::ExpandAsGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp);
REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp,
ops::ExpandAsGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_as, ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
......
......@@ -59,9 +59,12 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
}
};
class Conv2DFusionOpInferShape : public framework::InferShapeBase {
class Conv2DFusionOp : public operators::ConvOp {
public:
void operator()(framework::InferShapeContext* ctx) const override {
using operators::ConvOp::ConvOp;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input(Input) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
......@@ -175,7 +178,7 @@ class Conv2DFusionOpInferShape : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(
conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
ops::Conv2DFusionOpInferShape, ops::ConvOpInferVarType,
conv2d_fusion, ops::Conv2DFusionOp, ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -45,8 +45,7 @@ class NgraphEngineInferVarType : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker,
ops::NgraphEngineOpMaker);
REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker);
REGISTER_OP_CPU_KERNEL(
ngraph_engine,
ops::NgraphEngineKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -20,6 +20,29 @@ class RandomCropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(
x_dim.size(), static_cast<int64_t>(shape.size()),
platform::errors::InvalidArgument(
"Rank of Input(X) must be equal to length of Attr(shape)"));
auto out_dim = framework::vectorize<int>(x_dim);
for (size_t i = 1; i <= shape.size(); ++i) {
size_t x_i = x_dim.size() - i;
size_t shape_i = shape.size() - i;
if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) {
PADDLE_ENFORCE_GE(
x_dim[x_i], shape[shape_i],
platform::errors::InvalidArgument(
"Size of Input(X) must be larger than Attr(shape)"));
}
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -51,25 +74,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class RandomCropOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
auto out_dim = framework::vectorize<int>(x_dim);
for (size_t i = 1; i <= shape.size(); ++i) {
size_t x_i = x_dim.size() - i;
size_t shape_i = shape.size() - i;
if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) {
PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
}
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
}
};
} // namespace operators
} // namespace paddle
......@@ -77,7 +81,6 @@ namespace ops = paddle::operators;
namespace f = paddle::framework;
REGISTER_OPERATOR(
random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
ops::RandomCropOpInferShape,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -77,18 +77,13 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
}
};
class SaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
ops::SaveOpVarTypeInference, ops::SaveOpShapeInference);
ops::SaveOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -35,9 +35,12 @@ class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class SeqConcatShapeInferer : public framework::InferShapeBase {
class SequenceConcatOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *context) const override {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("X"),
"Input(X) of Sequence Concat Op should not be null.");
PADDLE_ENFORCE(context->HasOutput("Out"),
......@@ -117,8 +120,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference,
namespace op = paddle::operators;
REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
REGISTER_OPERATOR(sequence_concat, op::SequenceConcatOp, op::SeqConcatOpMaker,
op::SeqConcatGradOpMaker<paddle::framework::OpDesc>,
op::SeqConcatGradOpMaker<paddle::imperative::OpBase>);
template <typename T>
......
......@@ -55,6 +55,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
ops::TensorRTEngineOpMaker);
#endif // PADDLE_WITH_CUDA
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册