From 6c6e6385507cfcf658e6a6f3ccc39e0ac353e06a Mon Sep 17 00:00:00 2001 From: chengduo Date: Fri, 9 Nov 2018 17:37:15 +0800 Subject: [PATCH] Add InferVarType for some op (#14201) * add_infer_var_type test=develop * InferVarTypeHelper-> VarTypeInferenceHelper test=develop * PassInputTypeAndDTypeOnOutput test=develop * follow comment test=develop --- paddle/fluid/framework/operator.cc | 2 ++ paddle/fluid/framework/var_type_inference.h | 25 +++++++++++++++++++++ paddle/fluid/operators/activation_op.cc | 16 +++++-------- paddle/fluid/operators/batch_norm_op.cc | 11 ++++++++- paddle/fluid/operators/conv_op.cc | 12 ++++++++++ paddle/fluid/operators/cross_entropy_op.cc | 11 +++++++++ paddle/fluid/operators/elementwise_op.h | 16 +++++-------- paddle/fluid/operators/mean_op.cc | 21 +++++++++++++++-- paddle/fluid/operators/mul_op.cc | 11 ++++++++- paddle/fluid/operators/pool_op.cc | 18 +++++++++++---- paddle/fluid/operators/softmax_op.cc | 10 ++++++++- 11 files changed, 124 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0506907ab5..5624878d43 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { if (row_size >= 0) { ss << "[row_size=" << row_size << "]"; } + std::string dtype = GetDtype(*scope, output.second[i]); + ss << ":" << dtype; ss << "[" << GetDims(*scope, var_name, true) << "]"; ss << "(" << GetLoD(*scope, var_name) << ")"; } diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index f3035cd712..64236b78d2 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/type_defs.h" namespace paddle { @@ -24,5 +27,27 @@ class VarTypeInference { virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0; }; +class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const final { + auto in_out_var_names = this->GetInputOutputWithSameType(); + + for (auto& i_o_n : in_out_var_names) { + auto& x_name = op_desc.Input(i_o_n.first).at(0); + auto& out_name = op_desc.Output(i_o_n.second).at(0); + + auto& x = block->FindRecursiveOrCreateVar(x_name); + auto& out = block->FindRecursiveOrCreateVar(out_name); + out.SetType(x.GetType()); + out.SetDataType(x.GetDataType()); + } + } + + protected: + virtual std::unordered_map + GetInputOutputWithSameType() const = 0; +}; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 9ddb3a5d29..ea260a3e92 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel { } }; -class ActivationOpInferVarType : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto x_name = op_desc.Input("X")[0]; - auto out_name = op_desc.Output("Out")[0]; - auto& x = block->FindRecursiveOrCreateVar(x_name); - auto& out = block->FindRecursiveOrCreateVar(out_name); - out.SetType(x.GetType()); - out.SetDataType(x.GetDataType()); +class ActivationOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 3eb4738325..cf245f5038 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -170,6 +170,15 @@ The required data format for this layer is one of the following: } }; +class BatchNormOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; + } +}; + template class BatchNormKernel : public framework::OpKernel { @@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - ops::BatchNormGradMaker); + ops::BatchNormOpInferVarType, ops::BatchNormGradMaker); REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 7401f100d7..4d37074638 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -224,6 +224,15 @@ $$ )DOC"); } +class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{ + {"Input", /*->*/ "Output"}}; + } +}; + void Conv3DOpMaker::Make() { AddInput( "Input", @@ -365,6 +374,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( namespace ops = paddle::operators; REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker, + ops::ConvOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); @@ -372,7 +382,9 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad); + REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker, + ops::ConvOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad); diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 66f19fe7ec..a904dd9130 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cross_entropy_op.h" +#include namespace paddle { namespace operators { @@ -179,6 +180,15 @@ or not. But the output only shares the LoD information with input X. )DOC"); } }; + +class CrossEntropyOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; + } +}; } // namespace operators } // namespace paddle @@ -186,6 +196,7 @@ namespace ops = paddle::operators; using CPUCtx = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, + ops::CrossEntropyOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel, diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 5eb4233344..f01f67692e 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { } }; -class ElementwiseOpInferVarType : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto x_name = op_desc.Input("X")[0]; - auto out_name = op_desc.Output("Out")[0]; - auto &x = block->FindRecursiveOrCreateVar(x_name); - auto &out = block->FindRecursiveOrCreateVar(out_name); - out.SetType(x.GetType()); - out.SetDataType(x.GetDataType()); +class ElementwiseOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; } }; diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 19426b3c20..820636defa 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mean_op.h" - +#include namespace paddle { namespace operators { @@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X. } }; +class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; + } +}; + class MeanGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareLoD("X", framework::GradVarName("X")); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class MeanGradMaker : public framework::SingleGradOpDescMaker { @@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker); +REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType, + ops::MeanGradMaker); REGISTER_OPERATOR(mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL( mean, ops::MeanKernel, diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index a2140ddc79..08f2949d4a 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$. } }; +class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; + } +}; + class MulGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGradMaker); +REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType, + ops::MulOpGradMaker); REGISTER_OPERATOR(mul_grad, ops::MulGradOp); REGISTER_OP_CPU_KERNEL( mul, ops::MulKernel, diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 484cb65746..46a95350a7 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride, return output_size; } -void PoolOp::InferShape(framework::InferShapeContext *ctx) const { +void PoolOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Out(Output) of Pooling should not be null."); @@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { } framework::OpKernelType PoolOp::GetExpectedKernelType( - const framework::ExecutionContext &ctx) const { + const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); @@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( layout_, library_); } -void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { +void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Input(X@GRAD) should not be null."); @@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { } framework::OpKernelType PoolOpGrad::GetExpectedKernelType( - const framework::ExecutionContext &ctx) const { + const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); @@ -262,6 +262,14 @@ Example: )DOC"); } +class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; + } +}; + void Pool3dOpMaker::Make() { AddInput("X", "(Tensor) The input tensor of pooling operator. " @@ -372,6 +380,7 @@ Example: namespace ops = paddle::operators; REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker, + ops::PoolOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad); @@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL( ops::PoolGradKernel); REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker, + ops::PoolOpInferVarType, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad); diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index a4bdbe6648..9e21b6c824 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -124,6 +124,14 @@ For each row $i$ and each column $j$ in the matrix, we have: } }; +class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; + } +}; + class SoftmaxOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -196,7 +204,7 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, - ops::SoftmaxOpGradMaker); + ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker); REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL( softmax, ops::SoftmaxKernel, -- GitLab