Unverified · Commit 0f888836 authored by Zeng Jinle, committed by GitHub

Polish op registry codes (#21561)

* polish infer shape registry, test=develop

* modify some operators registry, test=develop
Parent 3d9dee57
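In short, every `OpInfoFiller<T, ...>` specialization now checks that its slot in `OpInfo` is still empty before filling it, so a duplicate registration fails loudly at registration time instead of silently overwriting the earlier entry; for kernel-based ops the compile-time InferShape function is bound eagerly at registration, which lets the lazy `InitInferShapeFuncs()` pass be deleted and the standalone `*InferShape` functor classes fold into their operator classes. A minimal, self-contained sketch of the guard idiom, with simplified stand-in names (this is not Paddle's real API):

#include <functional>
#include <stdexcept>
#include <string>

// Stand-in for OpInfo: an empty std::function means "not registered yet".
struct OpInfo {
  std::function<void()> infer_shape_;
};

// Mirrors the new PADDLE_ENFORCE_EQ(info->infer_shape_, nullptr, ...) guards.
template <typename T>
void FillInferShape(const std::string& op_type, OpInfo* info) {
  if (info->infer_shape_ != nullptr) {
    throw std::runtime_error("Duplicate InferShapeFN of " + op_type +
                             " has been registered");
  }
  info->infer_shape_ = [] { T{}(); };  // T is a callable inference functor
}

struct DummyInferer {
  void operator()() const { /* compute and set output dims here */ }
};

int main() {
  OpInfo info;
  FillInferShape<DummyInferer>("cast", &info);  // first registration: OK
  try {
    FillInferShape<DummyInferer>("cast", &info);  // second: throws
  } catch (const std::runtime_error&) {
    // duplicate registration detected at startup, easy to diagnose
  }
  return 0;
}

The second registration for the same op type trips the guard immediately, which is far easier to debug than a silently replaced function pointer.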
@@ -155,17 +155,41 @@ class OperatorRegistrarRecursive<I, true, ARGS...> {
 template <typename T>
 struct OpInfoFiller<T, kOperator> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(info->creator_, nullptr,
+                      platform::errors::AlreadyExists(
+                          "OpCreator of %s has been registered", op_type));
     info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
                         const VariableNameMap& outputs,
                         const AttributeMap& attrs) {
       return new T(type, inputs, outputs, attrs);
     };
+
+    if (std::is_base_of<OperatorWithKernel, T>::value) {
+      PADDLE_ENFORCE_EQ(
+          info->infer_shape_, nullptr,
+          platform::errors::AlreadyExists(
+              "Duplicate InferShapeFN of %s has been registered", op_type));
+      auto* op =
+          dynamic_cast<OperatorWithKernel*>(info->creator_("", {}, {}, {}));
+      PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument(
+                                      "%s should have kernels", op_type));
+      info->infer_shape_ = [op](InferShapeContext* ctx) {
+        op->InferShape(ctx);
+      };
+    }
   }
 };

 template <typename T>
 struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(info->proto_, nullptr,
+                      platform::errors::AlreadyExists(
+                          "OpProto of %s has been registered", op_type));
+    PADDLE_ENFORCE_EQ(info->checker_, nullptr,
+                      platform::errors::AlreadyExists(
+                          "OpAttrChecker of %s has been registered", op_type));
     info->proto_ = new proto::OpProto;
     info->checker_ = new OpAttrChecker();
     T maker;
@@ -181,6 +205,11 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
 template <typename T>
 struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->grad_op_maker_, nullptr,
+        platform::errors::AlreadyExists(
+            "GradOpDescMaker of %s has been registered", op_type));
     info->grad_op_maker_ = [](
         const OpDesc& fwd_op,
         const std::unordered_set<std::string>& no_grad_set,
@@ -199,6 +228,11 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kGradOpBaseMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->dygraph_grad_op_maker_, nullptr,
+        platform::errors::AlreadyExists(
+            "GradOpBaseMaker of %s has been registered", op_type));
     info->dygraph_grad_op_maker_ = [](
         const imperative::OpBase* fw_op_base,
         const imperative::NameVarBaseMap& var_base_map_in,
@@ -212,6 +246,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->infer_var_type_, nullptr,
+        platform::errors::AlreadyExists(
+            "VarTypeInference of %s has been registered", op_type));
     info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
       inference(context);
@@ -222,6 +260,10 @@ struct OpInfoFiller<T, kVarTypeInference> {
 template <typename T>
 struct OpInfoFiller<T, kShapeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->infer_shape_, nullptr,
+        platform::errors::AlreadyExists(
+            "Duplicate InferShapeFN of %s has been registered", op_type));
     info->infer_shape_ = [](InferShapeContext* ctx) {
       T inference;
       inference(ctx);
@@ -232,6 +274,10 @@ struct OpInfoFiller<T, kShapeInference> {
 template <typename T>
 struct OpInfoFiller<T, kInplaceOpInference> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->infer_inplace_, nullptr,
+        platform::errors::AlreadyExists(
+            "InplaceOpInference of %s has been registered", op_type));
     info->infer_inplace_ = [](bool use_cuda) {
       T infer;
       return infer(use_cuda);
@@ -242,6 +288,10 @@ struct OpInfoFiller<T, kInplaceOpInference> {
 template <typename T>
 struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
   void operator()(const char* op_type, OpInfo* info) const {
+    PADDLE_ENFORCE_EQ(
+        info->infer_no_need_buffer_vars_, nullptr,
+        platform::errors::AlreadyExists(
+            "NoNeedBufferVarsInference of %s has been registered", op_type));
     info->infer_no_need_buffer_vars_.Reset(std::make_shared<T>());
   }
 };
...
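For operators derived from `OperatorWithKernel`, the `kOperator` filler above now binds `infer_shape_` eagerly: it builds one throwaway operator instance through the just-installed creator and captures it in the lambda, so one prototype per op type lives for the whole process. The `std::is_base_of` test is a compile-time constant evaluated by an ordinary runtime `if` (pre-C++17 style), and the `dynamic_cast` keeps the branch compilable for every `T`. A reduced sketch of the same mechanics, with illustrative names only:

#include <functional>
#include <iostream>
#include <type_traits>

struct OperatorBase {
  virtual ~OperatorBase() = default;
};

struct OperatorWithKernel : OperatorBase {
  virtual void InferShape() const { std::cout << "infer shape\n"; }
};

struct Registry {
  std::function<void()> infer_shape_;
};

template <typename T>
void RegisterOperator(Registry* reg) {
  // Compile-time constant in a plain runtime `if`, as in the patch;
  // the dynamic_cast keeps this body valid even for non-kernel ops.
  if (std::is_base_of<OperatorWithKernel, T>::value) {
    auto* op = dynamic_cast<OperatorWithKernel*>(
        static_cast<OperatorBase*>(new T()));  // prototype, intentionally leaked
    reg->infer_shape_ = [op] { op->InferShape(); };
  }
}

struct MyKernelOp : OperatorWithKernel {};

int main() {
  Registry reg;
  RegisterOperator<MyKernelOp>(&reg);
  if (reg.infer_shape_) reg.infer_shape_();  // prints "infer shape"
  return 0;
}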
@@ -124,9 +124,23 @@ class InferNoNeedBufferVarsFN {
     inferer_ = inferer;
   }

+  inline bool operator==(std::nullptr_t) const { return inferer_ == nullptr; }
+
+  inline bool operator!=(std::nullptr_t) const { return inferer_ != nullptr; }
+
  private:
   std::shared_ptr<NoNeedBufferVarsInference> inferer_;
 };

+static inline bool operator==(std::nullptr_t,
+                              const InferNoNeedBufferVarsFN &other) {
+  return other == nullptr;
+}
+
+static inline bool operator!=(std::nullptr_t,
+                              const InferNoNeedBufferVarsFN &other) {
+  return other != nullptr;
+}
+
 }  // namespace framework
 }  // namespace paddle
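The member `operator==(std::nullptr_t)` only covers the `fn == nullptr` spelling; with the operands reversed, `nullptr == fn`, pre-C++20 overload resolution does not flip the arguments for you, hence the extra free functions. A self-contained illustration of the same two-sided pattern (illustrative `Callback` type, not Paddle's):

#include <cassert>
#include <cstddef>
#include <memory>

class Callback {
 public:
  void Reset(std::shared_ptr<int> p) { impl_ = std::move(p); }

  // Member operators only handle `cb == nullptr` / `cb != nullptr`.
  bool operator==(std::nullptr_t) const { return impl_ == nullptr; }
  bool operator!=(std::nullptr_t) const { return impl_ != nullptr; }

 private:
  std::shared_ptr<int> impl_;
};

// Free overloads handle the reversed `nullptr == cb` / `nullptr != cb`;
// before C++20's rewritten candidates, these do not come for free.
inline bool operator==(std::nullptr_t, const Callback& cb) { return cb == nullptr; }
inline bool operator!=(std::nullptr_t, const Callback& cb) { return cb != nullptr; }

int main() {
  Callback cb;
  assert(cb == nullptr && nullptr == cb);
  cb.Reset(std::make_shared<int>(1));
  assert(cb != nullptr && nullptr != cb);
  return 0;
}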
@@ -48,5 +48,31 @@ TEST(test_no_need_buffer_vars_inference, test_dygraph) {
   ASSERT_TRUE(boost::get<bool>(ctx.GetAttr("is_test")));
 }

+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(TestNoNeedBufferVarsInferer, "X1", "X2");
+
+TEST(test_no_need_buffer_vars_inference, test_nullptr_comparation) {
+  InferNoNeedBufferVarsFN infer_fn;
+  ASSERT_FALSE(static_cast<bool>(infer_fn));
+  ASSERT_TRUE(!infer_fn);
+  ASSERT_TRUE(infer_fn == nullptr);
+  ASSERT_TRUE(nullptr == infer_fn);
+  ASSERT_FALSE(infer_fn != nullptr);
+  ASSERT_FALSE(nullptr != infer_fn);
+
+  infer_fn.Reset(std::make_shared<TestNoNeedBufferVarsInferer>());
+  ASSERT_TRUE(static_cast<bool>(infer_fn));
+  ASSERT_FALSE(!infer_fn);
+  ASSERT_FALSE(infer_fn == nullptr);
+  ASSERT_FALSE(nullptr == infer_fn);
+  ASSERT_TRUE(infer_fn != nullptr);
+  ASSERT_TRUE(nullptr != infer_fn);
+
+  auto no_need_slots =
+      infer_fn(VariableNameMap{}, VariableNameMap{}, AttributeMap{});
+  ASSERT_EQ(no_need_slots.size(), 2UL);
+  ASSERT_EQ(no_need_slots.count("X1"), 1UL);
+  ASSERT_EQ(no_need_slots.count("X2"), 1UL);
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -653,55 +653,6 @@ void OpDesc::Flush() {
   }
 }

-static std::once_flag init_infer_shape_funcs;
-
-/**
- * NOTE(paddle-dev): Very tricky code here. Maybe we should find a
- * better way to register the compile-time InferShape method gently.
- *
- * Normally, we can register a class derived from InferShapeBase, so that
- * the field `infer_shape_` inside OpInfo is set when registering the op.
- *
- * However, there is another way to set the field `infer_shape_` inside
- * OpInfo. Usually, we overload the InferShape method of OperatorWithKernel.
- * After the following method InitInferShapeFuncs runs, `infer_shape_` is set
- * to the InferShape method of OperatorWithKernel. That is to say, we borrow
- * the run-time InferShape method of OperatorWithKernel to serve as the
- * compile-time InferShape method.
- *
- * However, at compile time we may not know the inputs, outputs and attrs of
- * a run-time OperatorWithKernel. So the following code creates a fake
- * OperatorWithKernel object. That is why the field info_ of OperatorBase
- * would be null.
- */
-static void InitInferShapeFuncs() {
-  std::call_once(init_infer_shape_funcs, [] {
-    auto &map = OpInfoMap::Instance();
-    auto &info_map = *map.mutable_map();
-
-    for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
-      auto op_type = kern_pair.first;
-      auto it = info_map.find(op_type);
-      PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
-                     op_type);
-      auto &op_info = it->second;
-      if (op_info.infer_shape_) {  // infer_shape has been registered.
-        continue;
-      }
-
-      auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
-          "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
-
-      PADDLE_ENFORCE_NOT_NULL(
-          op, "InferShapeBase is not registered to Operator %s", op_type);
-
-      op_info.infer_shape_ = [op](InferShapeContext *ctx) {
-        op->InferShape(ctx);
-      };
-    }
-  });
-}
-
 void OpDesc::CheckAttrs() {
   PADDLE_ENFORCE(!Type().empty(),
                  "CheckAttr() can not be called before type is set.");
@@ -718,7 +669,6 @@ void OpDesc::CheckAttrs() {
 void OpDesc::InferShape(const BlockDesc &block) const {
   try {
     VLOG(3) << "CompileTime infer shape on " << Type();
-    InitInferShapeFuncs();
     auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
     PADDLE_ENFORCE(static_cast<bool>(infer_shape),
                    "%s's infer_shape has not been registered", this->Type());
...
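With `infer_shape_` now filled at registration time by `OpInfoFiller<T, kOperator>`, the lazy fixup pass above is dead code. The removed helper leaned on `std::call_once`, which guarantees the wrapped callable runs exactly once even under concurrent callers; a minimal, Paddle-free reminder of that primitive:

#include <iostream>
#include <mutex>
#include <thread>

static std::once_flag init_flag;

void InitOnce() {
  // The callable runs exactly once, no matter how many threads race here.
  std::call_once(init_flag, [] { std::cout << "initialized\n"; });
}

int main() {
  std::thread t1(InitOnce), t2(InitOnce);
  t1.join();
  t2.join();  // "initialized" is printed a single time
  return 0;
}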
@@ -38,17 +38,6 @@ the input dtype, but it's fine if you do so.
   }
 };

-class CastOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set");
-    PADDLE_ENFORCE(context->HasOutput("Out"),
-                   "The output of cast op must be set");
-    context->SetOutputDim("Out", context->GetInputDim("X"));
-    context->ShareLoD("X", "Out");
-  }
-};
-
 template <typename T>
 class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -71,6 +60,17 @@ class CastOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("X"), true,
+        platform::errors::NotFound("The input(X) of cast op must be set"));
+    PADDLE_ENFORCE_EQ(
+        context->HasOutput("Out"), true,
+        platform::errors::NotFound("The output of cast op must be set"));
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+    context->ShareLoD("X", "Out");
+  }
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
@@ -88,7 +88,7 @@ using CPU = paddle::platform::CPUDeviceContext;
 REGISTER_OPERATOR(cast, ops::CastOp,
                   ops::CastOpGradMaker<paddle::framework::OpDesc>,
                   ops::CastOpGradMaker<paddle::imperative::OpBase>,
-                  ops::CastOpInferShape, ops::CastOpProtoMaker);
+                  ops::CastOpProtoMaker);
 REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
                        ops::CastOpKernel<CPU, double>,
                        ops::CastOpKernel<CPU, int>,
...
@@ -73,9 +73,12 @@ calculated by $%s$
 };

 template <typename OpComment>
-class CompareOpInferShape : public framework::InferShapeBase {
+class CompareOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext* context) const override {
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* context) const override {
     OpComment comment;
     PADDLE_ENFORCE(context->HasInput("X"), "%s operator must have input X",
                    comment.type);
@@ -89,13 +92,7 @@ class CompareOpInferShape : public framework::InferShapeBase {
     context->SetOutputDim("Out", context->GetInputDim("X"));
     context->ShareLoD("X", "Out");
   }
-};
-
-class CompareOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
@@ -118,9 +115,8 @@ class CompareOp : public framework::OperatorWithKernel {
   char _##op_type##Comment::type[]{#op_type};                              \
   char _##op_type##Comment::equation[]{_equation};                         \
   REGISTER_OPERATOR(                                                       \
-      op_type, ::paddle::operators::CompareOp,                             \
+      op_type, ::paddle::operators::CompareOp<_##op_type##Comment>,        \
       ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>,       \
-      ::paddle::operators::CompareOpInferShape<_##op_type##Comment>,       \
       ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,    \
       ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...
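Here (and in the logical ops below) the standalone `*InferShape` functor is folded into the operator class, which becomes a template over the comment struct; each comparison op instantiates its own `CompareOp<_##op_type##Comment>`, and error messages can name the concrete op. A reduced sketch of that comment-struct pattern (illustrative names, no Paddle dependencies):

#include <cstdio>

// Comment structs carry per-op metadata, as in the _##op_type##Comment macro.
struct GreaterThanComment {
  static const char type[];
  static const char equation[];
};
const char GreaterThanComment::type[] = "greater_than";
const char GreaterThanComment::equation[] = "Out = X > Y";

template <typename OpComment>
class CompareOpSketch {
 public:
  void InferShape() const {
    OpComment comment;
    // Messages can reference the concrete op via comment.type.
    std::printf("%s (%s): Out takes the shape of X\n", comment.type,
                comment.equation);
  }
};

int main() {
  CompareOpSketch<GreaterThanComment>{}.InferShape();
  return 0;
}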
@@ -57,10 +57,44 @@ Each element of Out is calculated by %s
   }
 };

+class LogicalOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    // LogicalOp kernel's device type is decided by input tensor place
+    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
+    return kt;
+  }
+};
+
+template <typename OpComment>
+class UnaryLogicalOp : public LogicalOp {
+ public:
+  using LogicalOp::LogicalOp;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    OpComment comment;
+    PADDLE_ENFORCE_EQ(
+        context->HasInput("X"), true,
+        platform::errors::NotFound("Input(X) of %s operator must not be null",
+                                   comment.type));
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+    context->ShareLoD("X", "Out");
+  }
+};
+
 template <typename OpComment>
-class BinaryLogicalOpInferShape : public framework::InferShapeBase {
+class BinaryLogicalOp : public LogicalOp {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
+  using LogicalOp::LogicalOp;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
     OpComment comment;
     PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
                       "Input(X) of %s operator must not be null", comment.type);
@@ -84,32 +118,6 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
   }
 };

-template <typename OpComment>
-class UnaryLogicalOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    OpComment comment;
-    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
-                      "Input(X) of %s operator must not be null", comment.type);
-    context->SetOutputDim("Out", context->GetInputDim("X"));
-    context->ShareLoD("X", "Out");
-  }
-};
-
-class LogicalOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
-    // LogicalOp kernel's device type is decided by input tensor place
-    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
-    return kt;
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle

@@ -121,9 +129,8 @@ class LogicalOp : public framework::OperatorWithKernel {
   char _##op_type##Comment::type[]{#op_type};                               \
   char _##op_type##Comment::equation[]{_equation};                          \
   REGISTER_OPERATOR(                                                        \
-      op_type, ::paddle::operators::LogicalOp,                              \
+      op_type, ::paddle::operators::BinaryLogicalOp<_##op_type##Comment>,   \
       ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>,  \
-      ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>,  \
       ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
       ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -135,9 +142,8 @@ class LogicalOp : public framework::OperatorWithKernel {
   char _##op_type##Comment::type[]{#op_type};                               \
   char _##op_type##Comment::equation[]{_equation};                          \
   REGISTER_OPERATOR(                                                        \
-      op_type, ::paddle::operators::LogicalOp,                              \
+      op_type, ::paddle::operators::UnaryLogicalOp<_##op_type##Comment>,    \
       ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>,   \
-      ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>,   \
       ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
       ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...
@@ -88,6 +88,13 @@ class ExpandAsGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
 };

 template <typename T>
@@ -108,7 +115,7 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };

-// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandAsGradNoNeedBufVarsInferer, "X");

 }  // namespace operators
 }  // namespace paddle
@@ -117,7 +124,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(expand_as, ops::ExpandAsOp, ops::ExpandAsOpMaker,
                   ops::ExpandAsGradOpMaker<paddle::framework::OpDesc>,
                   ops::ExpandAsGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp);
+REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp,
+                  ops::ExpandAsGradNoNeedBufVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     expand_as, ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
...
@@ -59,9 +59,12 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
   }
 };

-class Conv2DFusionOpInferShape : public framework::InferShapeBase {
+class Conv2DFusionOp : public operators::ConvOp {
  public:
-  void operator()(framework::InferShapeContext* ctx) const override {
+  using operators::ConvOp::ConvOp;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
                       "Input(Input) of ConvOp should not be null.");
     PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
@@ -175,7 +178,7 @@ class Conv2DFusionOpInferShape : public framework::InferShapeBase {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
-    conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
-    ops::Conv2DFusionOpInferShape, ops::ConvOpInferVarType,
+    conv2d_fusion, ops::Conv2DFusionOp, ops::Conv2DFusionOpMaker,
+    ops::ConvOpInferVarType,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -45,8 +45,7 @@ class NgraphEngineInferVarType : public framework::VarTypeInference {
 namespace ops = paddle::operators;

-REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker,
-                  ops::NgraphEngineOpMaker);
+REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker);

 REGISTER_OP_CPU_KERNEL(
     ngraph_engine,
     ops::NgraphEngineKernel<paddle::platform::CPUDeviceContext, float>);
@@ -20,6 +20,29 @@ class RandomCropOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    auto x_dim = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GT(
+        x_dim.size(), static_cast<int64_t>(shape.size()),
+        platform::errors::InvalidArgument(
+            "Rank of Input(X) must be greater than length of Attr(shape)"));
+    auto out_dim = framework::vectorize<int>(x_dim);
+    for (size_t i = 1; i <= shape.size(); ++i) {
+      size_t x_i = x_dim.size() - i;
+      size_t shape_i = shape.size() - i;
+      if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) {
+        PADDLE_ENFORCE_GE(
+            x_dim[x_i], shape[shape_i],
+            platform::errors::InvalidArgument(
+                "Size of Input(X) must be no less than Attr(shape)"));
+      }
+      out_dim[x_i] = shape[shape_i];
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
+  }
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
@@ -51,25 +74,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };

-class RandomCropOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext* ctx) const override {
-    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
-    auto x_dim = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
-    auto out_dim = framework::vectorize<int>(x_dim);
-    for (size_t i = 1; i <= shape.size(); ++i) {
-      size_t x_i = x_dim.size() - i;
-      size_t shape_i = shape.size() - i;
-      if (ctx->IsRuntime() || (x_dim[x_i] > 0 && shape[shape_i] > 0)) {
-        PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
-      }
-      out_dim[x_i] = shape[shape_i];
-    }
-    ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
@@ -77,7 +81,6 @@ namespace ops = paddle::operators;
 namespace f = paddle::framework;
 REGISTER_OPERATOR(
     random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
-    ops::RandomCropOpInferShape,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...
@@ -77,18 +77,13 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
   }
 };

-class SaveOpShapeInference : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *ctx) const override {}
-};
-
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;

 REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
-                  ops::SaveOpVarTypeInference, ops::SaveOpShapeInference);
+                  ops::SaveOpVarTypeInference);

 REGISTER_OP_CPU_KERNEL(
     save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -35,9 +35,12 @@ class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };

-class SeqConcatShapeInferer : public framework::InferShapeBase {
+class SequenceConcatOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInputs("X"),
                    "Input(X) of Sequence Concat Op should not be null.");
     PADDLE_ENFORCE(context->HasOutput("Out"),
@@ -117,8 +120,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference,
 namespace op = paddle::operators;

-REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
-                  op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
+REGISTER_OPERATOR(sequence_concat, op::SequenceConcatOp, op::SeqConcatOpMaker,
                   op::SeqConcatGradOpMaker<paddle::framework::OpDesc>,
                   op::SeqConcatGradOpMaker<paddle::imperative::OpBase>);

 template <typename T>
...
@@ -55,6 +55,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
 namespace ops = paddle::operators;

 REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
-                  ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
+                  ops::TensorRTEngineOpMaker);

 #endif  // PADDLE_WITH_CUDA