提交 e0436ad8 编写于 作者: T tensor-tang

Refine the fusion LSTM operator's InferShape implementation.

上级 94b66bdb
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/shape_runtime_infer.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -458,12 +459,7 @@ bool OpSupportGPU(const std::string& op_type) { ...@@ -458,12 +459,7 @@ bool OpSupportGPU(const std::string& op_type) {
return false; return false;
} }
class RuntimeInferShapeContext : public InferShapeContext { bool RuntimeInferShapeContext::HasInput(const std::string& name) const {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) { if (!op_.HasInputs(name)) {
return false; return false;
} }
...@@ -477,9 +473,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -477,9 +473,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto ipt = ins[0]; auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool RuntimeInferShapeContext::HasOutput(const std::string& name) const {
if (!op_.HasOutputs(name)) { if (!op_.HasOutputs(name)) {
return false; return false;
} }
...@@ -493,9 +489,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -493,9 +489,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto ipt = outs[0]; auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool RuntimeInferShapeContext::HasInputs(const std::string& name) const {
if (!op_.HasInputs(name)) { if (!op_.HasInputs(name)) {
return false; return false;
} }
...@@ -509,9 +505,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -509,9 +505,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
} }
return true; return true;
} }
bool HasOutputs(const std::string& name) const override { bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const {
if (!op_.HasOutputs(name)) { if (!op_.HasOutputs(name)) {
return false; return false;
} }
...@@ -525,22 +521,11 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -525,22 +521,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
} }
return true; return true;
} }
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void RuntimeInferShapeContext::ShareLoD(const std::string& in,
size_t j = 0) const override { const std::string& out, size_t i,
size_t j) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]); Variable* in_var = scope_.FindVar(Inputs(in)[i]);
...@@ -571,10 +556,11 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -571,10 +556,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (in_tensor.layout() != DataLayout::kMKLDNN) if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif #endif
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0, void RuntimeInferShapeContext::ShareLayout(const std::string& in,
size_t j = 0) const { const std::string& out, size_t i,
size_t j) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]); Variable* in_var = scope_.FindVar(Inputs(in)[i]);
...@@ -585,12 +571,9 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -585,12 +571,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto in_tensor = in_var->Get<LoDTensor>(); auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>(); auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
bool IsRuntime() const override { return true; }
protected: DDim RuntimeInferShapeContext::GetDim(const std::string& name) const {
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
...@@ -603,42 +586,20 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -603,42 +586,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
"type_id is %s.", "type_id is %s.",
name, var->Type().name()); name, var->Type().name());
} }
} }
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW("Only compile time support this method");
}
void SetDim(const std::string& name, const DDim& dim) override { void RuntimeInferShapeContext::SetDim(const std::string& name,
const DDim& dim) {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else { } else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", name,
name, var->Type().name()); var->Type().name());
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW("Only compile time support this method");
}
proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
InferShapeVarPtr GetVarPtr(const std::string& name) override {
return scope_.FindVar(name);
} }
}
private:
const OperatorBase& op_;
const Scope& scope_;
};
static void CheckTensorNANOrInf(const std::string& name, static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) { const framework::Tensor& tensor) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace framework {
// Execution-time implementation of InferShapeContext.
// Unlike the compile-time context (which works off the program
// description), this one resolves variable names against a live Scope,
// so shape inference sees the actual runtime variables.
class RuntimeInferShapeContext : public InferShapeContext {
public:
// Borrows both references; the op and the scope must outlive this
// context object.
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
// Defined out-of-line (operator.cc): each checks that the op declares
// the slot and that the referenced variable exists in scope_.
bool HasInput(const std::string& name) const override;
bool HasOutput(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name) const override;
// Direct accessors so callers that know they are at runtime (e.g. a
// fused op's InferShape after a dynamic_cast to this type) can inspect
// the op's input/output maps and the scope without the virtual API.
const OperatorBase& OpBase() const { return op_; }
const Scope& InferScope() const { return scope_; }
// Attribute access is delegated straight to the op's attribute map.
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}
// Propagates LoD from input slot `in`[i] to output slot `out`[j];
// defined out-of-line.
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override;
// Propagates tensor layout the same way. NOTE: not part of the
// InferShapeContext interface (no `override`) — runtime-only helper.
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const;
// Always true: this context exists only during execution.
bool IsRuntime() const override { return true; }
protected:
// Dim get/set look up the variable in scope_; defined out-of-line.
DDim GetDim(const std::string& name) const override;
void SetDim(const std::string& name, const DDim& dim) override;
// Repeated-dim APIs are compile-time only; at runtime they abort.
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW("Only compile time support this method");
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW("Only compile time support this method");
}
// Maps the scope variable's C++ type to the proto VarType enum.
proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
// Returns the raw Variable* as the runtime flavor of InferShapeVarPtr.
InferShapeVarPtr GetVarPtr(const std::string& name) override {
return scope_.FindVar(name);
}
private:
const OperatorBase& op_;    // the operator whose shapes are being inferred
const Scope& scope_;        // runtime scope holding the actual variables
};
} // namespace framework
} // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h" #include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/shape_runtime_infer.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
...@@ -24,26 +25,54 @@ namespace paddle { ...@@ -24,26 +25,54 @@ namespace paddle {
namespace operators { namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
PADDLE_ENFORCE(ctx->HasInput("WeightX"), if (runtime_ctx == nullptr) {
"Input(WeightX) of LSTM should not be null."); LOG(FATAL) << "Should have runtime infer context";
PADDLE_ENFORCE(ctx->HasInput("WeightH"), }
"Input(WeightH) of LSTM should not be null."); const auto& ins = runtime_ctx->OpBase().Inputs();
PADDLE_ENFORCE(ctx->HasInput("Bias"), const auto& outs = runtime_ctx->OpBase().Outputs();
"Input(Bias) of LSTM should not be null."); const auto& scope = runtime_ctx->InferScope();
const auto ins_end = ins.end();
PADDLE_ENFORCE(ctx->HasOutput("XX"), const auto outs_end = outs.end();
"Output(XX) of LSTM should not be null."); auto fair_input = [&](const std::string& name) -> bool {
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), auto it = ins.find(name);
"Output(Hidden) of LSTM should not be null."); if (it == ins_end) {
PADDLE_ENFORCE(ctx->HasOutput("Cell"), return false;
"Output(Cell) of LSTM should not be null."); }
const auto& in = it->second;
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
return false;
}
return scope.FindVar(in[0]) != nullptr;
};
auto fair_output = [&](const std::string& name) -> bool {
auto it = outs.find(name);
if (it == outs_end) {
return false;
}
const auto& out = it->second;
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
return false;
}
return scope.FindVar(out[0]) != nullptr;
};
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of LSTM.");
PADDLE_ENFORCE(fair_input("WeightX"),
"Assert only one Input(WeightX) of LSTM.");
PADDLE_ENFORCE(fair_input("WeightH"),
"Assert only one Input(WeightH) of LSTM.");
PADDLE_ENFORCE(fair_input("Bias"), "Assert only one Input(Bias) of LSTM.");
PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of LSTM.");
PADDLE_ENFORCE(fair_output("Hidden"),
"Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(fair_output("Cell"), "Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) { if (fair_input("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE(fair_input("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not " "Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time."); "be null at the same time.");
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
...@@ -95,16 +124,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -95,16 +124,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), PADDLE_ENFORCE(fair_output("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null."); "Assert only one Output(BatchedInput) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), PADDLE_ENFORCE(fair_output("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null."); "Assert only one Output(BatchedHidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), PADDLE_ENFORCE(fair_output("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null."); "Assert only one Output(BatchedCell) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), PADDLE_ENFORCE(fair_output("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null."); "Assert only one Output(ReorderedH0) of LSTM");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), PADDLE_ENFORCE(fair_output("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null."); "Assert only one Output(ReorderedC0) of LSTM.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册