提交 bbdac7f7 编写于 作者: Y Yu Yang

Polish OpWithKernel

* Chage `IndicateDataType` to `GetKernelType`. Make it easier to
  understand.
* Change `OpKernelKey` to `OpKernelType`
* Make operator developers can customize which kernel the operator will
  use in runtime.
上级 f74fb790
......@@ -55,6 +55,6 @@ After float16 class is available, some of the future items are below:
- Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
- Modify `IndicateDataType()` method in `framework/operator.h` to make it compatible with float16.
- Modify `GetKernelType()` method in `framework/operator.h` to make it compatible with float16.
- Create a type-casting operator that can convert the data type in tensor between float16 and other types.
......@@ -92,8 +92,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void operator()(const char* op_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
......
......@@ -254,8 +254,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key) {
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
return os;
......@@ -432,7 +431,7 @@ void OperatorWithKernel::Run(const Scope& scope,
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_key = GetKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
......@@ -444,6 +443,38 @@ void OperatorWithKernel::Run(const Scope& scope,
// throws errors if have.
dev_ctx.Finish();
}
OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.device_context());
}
DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same.", Type());
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
} // namespace framework
} // namespace paddle
......@@ -345,27 +345,10 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
platform::Place place_;
DataType data_type_;
OpKernelKey(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};
struct OpKernelHash {
struct OpKernelType {
struct Hash {
std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const {
size_t operator()(const OpKernelType& key) const {
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
......@@ -374,9 +357,26 @@ class OperatorWithKernel : public OperatorBase {
}
};
platform::Place place_;
DataType data_type_;
OpKernelType(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};
class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
OpKernelType::Hash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
......@@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same.",
Type());
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key);
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
extern bool OpSupportGPU(const std::string& op_type);
......
......@@ -114,8 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.device_context());
}
};
......
......@@ -47,10 +47,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.device_context());
}
};
......
......@@ -39,10 +39,11 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.device_context());
}
};
......
......@@ -303,7 +303,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
framework::DataType IndicateDataType(
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
......@@ -318,7 +319,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::ToDataType(t->type());
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.device_context());
}
};
......
......@@ -120,9 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
};
} // namespace operators
......
......@@ -51,9 +51,11 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......@@ -98,9 +100,11 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -49,9 +49,11 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
......@@ -33,11 +33,12 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
int data_type = ctx.Attr<int>("data_type");
VLOG(10) << " FillConstant data_type = " << data_type;
return static_cast<framework::DataType>(data_type);
return framework::OpKernelType(static_cast<framework::DataType>(data_type),
ctx.device_context());
}
};
......
......@@ -40,9 +40,11 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......@@ -55,9 +57,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -57,9 +57,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
......@@ -183,9 +183,11 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
};
......@@ -240,10 +242,13 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type());
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
ctx.device_context());
}
};
......
......@@ -41,9 +41,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
}
};
......@@ -97,9 +99,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
}
};
......
......@@ -84,10 +84,11 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
}
};
......@@ -245,10 +246,11 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
}
};
......
......@@ -51,9 +51,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
}
};
......@@ -107,9 +109,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
}
};
......
......@@ -85,9 +85,11 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Score")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
ctx.device_context());
}
};
......
......@@ -80,9 +80,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
ctx.device_context());
}
};
......
......@@ -49,9 +49,11 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
ctx.device_context());
}
};
......@@ -66,9 +68,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
ctx.device_context());
}
};
......
......@@ -107,9 +107,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
}
};
......
......@@ -121,9 +121,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context());
}
};
......@@ -160,10 +162,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type()),
ctx.device_context());
}
};
......
......@@ -47,20 +47,24 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::LoDTensor>().type());
return framework::OpKernelType(
framework::ToDataType(x_vars[0]->Get<framework::LoDTensor>().type()),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type());
return framework::OpKernelType(
framework::ToDataType(
x_vars[0]->Get<framework::SelectedRows>().value().type()),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
for (auto& each : array) {
if (each.numel() != 0) {
return framework::ToDataType(each.type());
return framework::OpKernelType(framework::ToDataType(each.type()),
ctx.device_context());
}
}
}
......
......@@ -63,9 +63,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::DataType IndicateDataType(
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
ctx.device_context());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册