Unverified · Commit 6c6e6385 authored by chengduo, committed by GitHub

Add InferVarType for some ops (#14201)

* add_infer_var_type
test=develop

* InferVarTypeHelper -> VarTypeInferenceHelper
test=develop

* PassInputTypeAndDTypeOnOutput
test=develop

* follow comment
test=develop
Parent 0b388226
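Every operator touched below follows the same pattern: derive from the new framework::PassInDtypeAndVarTypeToOutput helper, override GetInputOutputWithSameType() to name the input/output pair that should share a var type and dtype, and add the class to the op's REGISTER_OPERATOR call. A minimal sketch for a hypothetical my_op (the MyOp* names and the "X"/"Out" pair are illustrative, not part of this commit):

namespace paddle {
namespace operators {

// Hypothetical example: copies input "X"'s var type (e.g. LoDTensor vs.
// SelectedRows) and data type (e.g. FP32 vs. FP64) onto output "Out"
// at compile (ProgramDesc) time.
class MyOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

}  // namespace operators
}  // namespace paddle

// Registration then wires the rule into the op (assuming MyOp and MyOpMaker
// exist):
// REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker, ops::MyOpInferVarType,
//                   paddle::framework::DefaultGradOpDescMaker<true>);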
@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
         if (row_size >= 0) {
           ss << "[row_size=" << row_size << "]";
         }
+        std::string dtype = GetDtype(*scope, output.second[i]);
+        ss << ":" << dtype;
         ss << "[" << GetDims(*scope, var_name, true) << "]";
         ss << "(" << GetLoD(*scope, var_name) << ")";
       }
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <string>
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
 
 namespace paddle {
@@ -24,5 +27,27 @@ class VarTypeInference {
   virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
 };
 
+class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const final {
+    auto in_out_var_names = this->GetInputOutputWithSameType();
+
+    for (auto& i_o_n : in_out_var_names) {
+      auto& x_name = op_desc.Input(i_o_n.first).at(0);
+      auto& out_name = op_desc.Output(i_o_n.second).at(0);
+
+      auto& x = block->FindRecursiveOrCreateVar(x_name);
+      auto& out = block->FindRecursiveOrCreateVar(out_name);
+      out.SetType(x.GetType());
+      out.SetDataType(x.GetDataType());
+    }
+  }
+
+ protected:
+  virtual std::unordered_map<std::string, std::string>
+  GetInputOutputWithSameType() const = 0;
+};
+
 } // namespace framework
 } // namespace paddle
@@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel {
   }
 };
 
-class ActivationOpInferVarType : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto x_name = op_desc.Input("X")[0];
-    auto out_name = op_desc.Output("Out")[0];
-    auto& x = block->FindRecursiveOrCreateVar(x_name);
-    auto& out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(x.GetType());
-    out.SetDataType(x.GetDataType());
+class ActivationOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
   }
 };
@@ -170,6 +170,15 @@ The required data format for this layer is one of the following:
   }
 };
 
+class BatchNormOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  }
+};
+
 template <typename T>
 class BatchNormKernel<platform::CPUDeviceContext, T>
     : public framework::OpKernel<T> {
@@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
-                  ops::BatchNormGradMaker);
+                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 
 REGISTER_OP_CPU_KERNEL(
@@ -224,6 +224,15 @@ $$
 )DOC");
 }
 
+class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{
+        {"Input", /*->*/ "Output"}};
+  }
+};
+
 void Conv3DOpMaker::Make() {
   AddInput(
       "Input",
@@ -365,6 +374,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
+                  ops::ConvOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
@@ -372,7 +382,9 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
 REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
+
 REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
+                  ops::ConvOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -179,6 +180,15 @@ or not. But the output only shares the LoD information with input X.
 )DOC");
   }
 };
 
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  }
+};
+
 } // namespace operators
 } // namespace paddle
@@ -186,6 +196,7 @@ namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
+                  ops::CrossEntropyOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
@@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
   }
 };
 
-class ElementwiseOpInferVarType : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto x_name = op_desc.Input("X")[0];
-    auto out_name = op_desc.Output("Out")[0];
-    auto &x = block->FindRecursiveOrCreateVar(x_name);
-    auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(x.GetType());
-    out.SetDataType(x.GetDataType());
+class ElementwiseOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
   }
 };
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mean_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X.
   }
 };
 
+class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
 class MeanGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class MeanGradMaker : public framework::SingleGradOpDescMaker {
@@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
 } // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
+REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
+                  ops::MeanGradMaker);
 REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(
     mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
@@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$.
   }
 };
 
+class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
 class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker {
 } // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGradMaker);
+REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
+                  ops::MulOpGradMaker);
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
 REGISTER_OP_CPU_KERNEL(
     mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
@@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
   return output_size;
 }
 
-void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
+void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Out"),
                  "Out(Output) of Pooling should not be null.");
@@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
 }
 
 framework::OpKernelType PoolOp::GetExpectedKernelType(
-    const framework::ExecutionContext &ctx) const {
+    const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
@@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
       layout_, library_);
 }
 
-void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
+void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
   PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                  "Input(X@GRAD) should not be null.");
@@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
 }
 
 framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
-    const framework::ExecutionContext &ctx) const {
+    const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
@@ -262,6 +262,14 @@ Example:
 )DOC");
 }
 
+class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
 void Pool3dOpMaker::Make() {
   AddInput("X",
            "(Tensor) The input tensor of pooling operator. "
@@ -372,6 +380,7 @@ Example:
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker,
+                  ops::PoolOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad);
@@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>);
 
 REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker,
+                  ops::PoolOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad);
@@ -124,6 +124,14 @@ For each row $i$ and each column $j$ in the matrix, we have:
   }
 };
 
+class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
 class SoftmaxOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -196,7 +204,7 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
-                  ops::SoftmaxOpGradMaker);
+                  ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
     softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,