未验证 提交 6c6e6385 编写于 作者: C chengduo 提交者: GitHub

Add InferVarType for some op (#14201)

* add_infer_var_type
test=develop

* InferVarTypeHelper-> VarTypeInferenceHelper
test=develop

* PassInputTypeAndDTypeOnOutput
 test=develop

* follow comment
test=develop
上级 0b388226
...@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
if (row_size >= 0) { if (row_size >= 0) {
ss << "[row_size=" << row_size << "]"; ss << "[row_size=" << row_size << "]";
} }
std::string dtype = GetDtype(*scope, output.second[i]);
ss << ":" << dtype;
ss << "[" << GetDims(*scope, var_name, true) << "]"; ss << "[" << GetDims(*scope, var_name, true) << "]";
ss << "(" << GetLoD(*scope, var_name) << ")"; ss << "(" << GetLoD(*scope, var_name) << ")";
} }
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
namespace paddle { namespace paddle {
...@@ -24,5 +27,27 @@ class VarTypeInference { ...@@ -24,5 +27,27 @@ class VarTypeInference {
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0; virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
}; };
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
}
}
protected:
virtual std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const = 0;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel {
} }
}; };
class ActivationOpInferVarType : public framework::VarTypeInference { class ActivationOpInferVarType
public: : public framework::PassInDtypeAndVarTypeToOutput {
void operator()(const framework::OpDesc& op_desc, protected:
framework::BlockDesc* block) const override { std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
auto x_name = op_desc.Input("X")[0]; const override {
auto out_name = op_desc.Output("Out")[0]; return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
} }
}; };
......
...@@ -170,6 +170,15 @@ The required data format for this layer is one of the following: ...@@ -170,6 +170,15 @@ The required data format for this layer is one of the following:
} }
}; };
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
template <typename T> template <typename T>
class BatchNormKernel<platform::CPUDeviceContext, T> class BatchNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker { ...@@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormGradMaker); ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -224,6 +224,15 @@ $$ ...@@ -224,6 +224,15 @@ $$
)DOC"); )DOC");
} }
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
{"Input", /*->*/ "Output"}};
}
};
void Conv3DOpMaker::Make() { void Conv3DOpMaker::Make() {
AddInput( AddInput(
"Input", "Input",
...@@ -365,6 +374,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -365,6 +374,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker, REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
...@@ -372,7 +382,9 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad); ...@@ -372,7 +382,9 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker, REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker, REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad); REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h" #include "paddle/fluid/operators/cross_entropy_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -179,6 +180,15 @@ or not. But the output only shares the LoD information with input X. ...@@ -179,6 +180,15 @@ or not. But the output only shares the LoD information with input X.
)DOC"); )DOC");
} }
}; };
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -186,6 +196,7 @@ namespace ops = paddle::operators; ...@@ -186,6 +196,7 @@ namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext; using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
ops::CrossEntropyOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>, REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
......
...@@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
} }
}; };
class ElementwiseOpInferVarType : public framework::VarTypeInference { class ElementwiseOpInferVarType
public: : public framework::PassInDtypeAndVarTypeToOutput {
void operator()(const framework::OpDesc &op_desc, protected:
framework::BlockDesc *block) const override { std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
auto x_name = op_desc.Input("X")[0]; const override {
auto out_name = op_desc.Output("Out")[0]; return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
auto &x = block->FindRecursiveOrCreateVar(x_name);
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
} }
}; };
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/operators/mean_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X. ...@@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X.
} }
}; };
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MeanGradOp : public framework::OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class MeanGradMaker : public framework::SingleGradOpDescMaker { class MeanGradMaker : public framework::SingleGradOpDescMaker {
...@@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker { ...@@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker); REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp); REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>, mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$. ...@@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$.
} }
}; };
class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MulGradOp : public framework::OperatorWithKernel { class MulGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGradMaker); REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
ops::MulOpGradMaker);
REGISTER_OPERATOR(mul_grad, ops::MulGradOp); REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>, mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride, ...@@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
return output_size; return output_size;
} }
void PoolOp::InferShape(framework::InferShapeContext *ctx) const { void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null."); "Out(Output) of Pooling should not be null.");
...@@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
} }
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
...@@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( ...@@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
layout_, library_); layout_, library_);
} }
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null."); "Input(X@GRAD) should not be null.");
...@@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { ...@@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
} }
framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
...@@ -262,6 +262,14 @@ Example: ...@@ -262,6 +262,14 @@ Example:
)DOC"); )DOC");
} }
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
void Pool3dOpMaker::Make() { void Pool3dOpMaker::Make() {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of pooling operator. " "(Tensor) The input tensor of pooling operator. "
...@@ -372,6 +380,7 @@ Example: ...@@ -372,6 +380,7 @@ Example:
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker, REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker,
ops::PoolOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad); REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad);
...@@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL(
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>); ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker, REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker,
ops::PoolOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad); REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad);
......
...@@ -124,6 +124,14 @@ For each row $i$ and each column $j$ in the matrix, we have: ...@@ -124,6 +124,14 @@ For each row $i$ and each column $j$ in the matrix, we have:
} }
}; };
class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -196,7 +204,7 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -196,7 +204,7 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
ops::SoftmaxOpGradMaker); ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>, softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册