提交 8bb824bb 编写于 作者: T tensor-tang

Refine InferShape HasInput and HasOutput for fusion operators

上级 c4394bc5
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/shape_runtime_infer.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -459,147 +458,184 @@ bool OpSupportGPU(const std::string& op_type) { ...@@ -459,147 +458,184 @@ bool OpSupportGPU(const std::string& op_type) {
return false; return false;
} }
bool RuntimeInferShapeContext::HasInput(const std::string& name) const { class RuntimeInferShapeContext : public InferShapeContext {
if (!op_.HasInputs(name)) { public:
return false; RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
} : op_(op), scope_(scope) {}
auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Input %s should not have more than one inputs", name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool RuntimeInferShapeContext::HasOutput(const std::string& name) const { bool HasInput(const std::string& name) const override {
if (!op_.HasOutputs(name)) { // has only one input
return false; const auto& ins = op_.Inputs();
} auto it = ins.find(name);
auto& outs = Outputs(name); if (it == ins.end()) {
size_t length = outs.size(); return false;
if (length == 0) { }
return false; const auto& in = it->second;
}
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool RuntimeInferShapeContext::HasInputs(const std::string& name) const { if (in.size() != 1 || in[0] == kEmptyVarName) {
if (!op_.HasInputs(name)) {
return false;
}
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false; return false;
} }
return scope_.FindVar(in[0]) != nullptr;
} }
return true;
}
bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const { bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) { // has only one output
return false; const auto& outs = op_.Outputs();
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() != 1 || out[0] == kEmptyVarName) {
return false;
}
return scope_.FindVar(out[0]) != nullptr;
} }
auto outputs = op_.Outputs(name);
if (outputs.empty()) { bool HasInputs(const std::string& name) const override {
return false; if (!op_.HasInputs(name)) {
return false;
}
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false;
}
}
return true;
} }
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) { bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
return false;
}
}
return true;
} }
return true;
}
void RuntimeInferShapeContext::ShareLoD(const std::string& in, AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::string& out, size_t i,
size_t j) const { const std::vector<std::string>& Inputs(
PADDLE_ENFORCE_LT(i, Inputs(in).size()); const std::string& name) const override {
PADDLE_ENFORCE_LT(j, Outputs(out).size()); return op_.Inputs(name);
Variable* in_var = scope_.FindVar(Inputs(in)[i]); }
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return; const std::vector<std::string>& Outputs(
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(), const std::string& name) const override {
"The %d-th output of Output(%s) must be LoDTensor.", j, out); return op_.Outputs(name);
auto in_tensor = in_var->Get<LoDTensor>(); }
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod()); void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators. // TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops. // Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor? // Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below // Fix me: ugly workaround below
// Correct solution: // Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead, // set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute() // layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between // of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD()) // input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape(). // is that layout transform may occur after InferShape().
// Workaround: // Workaround:
// Skip set_layout() when input layout is kMKLDNN // Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute() // in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN) if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif #endif
out_tensor->set_layout(in_tensor.layout());
}
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
void RuntimeInferShapeContext::ShareLayout(const std::string& in, bool IsRuntime() const override { return true; }
const std::string& out, size_t i,
size_t j) const { protected:
PADDLE_ENFORCE_LT(i, Inputs(in).size()); DDim GetDim(const std::string& name) const override {
PADDLE_ENFORCE_LT(j, Outputs(out).size()); Variable* var = scope_.FindVar(name);
Variable* in_var = scope_.FindVar(Inputs(in)[i]); PADDLE_ENFORCE_NOT_NULL(var);
Variable* out_var = scope_.FindVar(Outputs(out)[j]); if (var->IsType<LoDTensor>()) {
if (!in_var->IsType<LoDTensor>()) return; return var->Get<LoDTensor>().dims();
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(), } else if (var->IsType<SelectedRows>()) {
"The %d-th output of Output(%s) must be LoDTensor.", j, out); return var->Get<SelectedRows>().GetCompleteDims();
auto in_tensor = in_var->Get<LoDTensor>(); } else {
auto* out_tensor = out_var->GetMutable<LoDTensor>(); PADDLE_THROW(
out_tensor->set_layout(in_tensor.layout()); "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
} "type_id is %s.",
name, var->Type().name());
DDim RuntimeInferShapeContext::GetDim(const std::string& name) const { }
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
} }
}
void RuntimeInferShapeContext::SetDim(const std::string& name, std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
const DDim& dim) { PADDLE_THROW("Only compile time support this method");
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", name,
var->Type().name());
} }
}
void SetDim(const std::string& name, const DDim& dim) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW("Only compile time support this method");
}
proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
InferShapeVarPtr GetVarPtr(const std::string& name) override {
return scope_.FindVar(name);
}
private:
const OperatorBase& op_;
const Scope& scope_;
};
static void CheckTensorNANOrInf(const std::string& name, static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) { const framework::Tensor& tensor) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace framework {
// Runtime implementation of InferShapeContext: resolves operator input/output
// names against a live Scope, so shape inference sees the actual Variables
// (LoDTensor / SelectedRows) instead of compile-time descriptions.
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  // Borrows `op` and `scope`; both must outlive this context.
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

  // True iff the slot exists, holds exactly one non-empty name, and that
  // variable is present in the scope (see definitions in operator.cc).
  bool HasInput(const std::string& name) const override;
  bool HasOutput(const std::string& name) const override;
  // True iff the slot exists, is non-empty, and every listed variable is
  // present in the scope.
  bool HasInputs(const std::string& name) const override;
  bool HasOutputs(const std::string& name) const override;

  // Accessors used by callers (e.g. fused-op InferShape helpers) that need
  // the raw operator description and scope.
  const OperatorBase& OpBase() const { return op_; }
  const Scope& InferScope() const { return scope_; }

  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

  // Variable-name lists for a given input/output slot, straight from the op.
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Inputs(name);
  }
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Outputs(name);
  }

  // Copy LoD (and, outside MKLDNN layouts, the data layout) from the i-th
  // variable of input slot `in` to the j-th variable of output slot `out`.
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override;
  // Copy only the data layout between the given input/output variables.
  void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
                   size_t j = 0) const;

  bool IsRuntime() const override { return true; }

 protected:
  // Dimensions of the named variable; supports LoDTensor and SelectedRows.
  DDim GetDim(const std::string& name) const override;
  void SetDim(const std::string& name, const DDim& dim) override;

  // Repeated-dim queries only make sense at compile time (ReaderHolder etc.).
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW("Only compile time support this method");
  }
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW("Only compile time support this method");
  }

  proto::VarType::Type GetVarType(const std::string& name) const override {
    auto* var = scope_.FindVar(name);
    // Guard against a missing variable before dereferencing, matching the
    // check in GetDim; FindVar returns nullptr for unknown names.
    PADDLE_ENFORCE_NOT_NULL(var);
    return ToVarType(var->Type());
  }

  InferShapeVarPtr GetVarPtr(const std::string& name) override {
    return scope_.FindVar(name);
  }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
};
} // namespace framework
} // namespace paddle
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h" #include "paddle/fluid/operators/attention_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/fusion_infershape_define.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
...@@ -24,28 +23,28 @@ namespace paddle { ...@@ -24,28 +23,28 @@ namespace paddle {
namespace operators { namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(ctx->HasInput("X"),
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM."); "Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
"Assert only one Input(C0) of AttentionLSTM."); "Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("LSTMWeight"), PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Assert only one Input(LSTMWeight) of AttentionLSTM."); "Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("LSTMBias"), PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Assert only one Input(LSTMBias) of AttentionLSTM."); "Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(fair_input("AttentionWeight"), PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Assert only one Input(AttentionWeight) of AttentionLSTM."); "Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of AttentionLSTM."); "Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Assert only one Output(Cell) of AttentionLSTM."); "Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("AttentionedX"), PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Assert only one Output(AttentionedX) of AttentionLSTM."); "Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("AttentionFCOut"), PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Assert only one Output(AttentionFCOut) of AttentionLSTM."); "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("LSTMX"), PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Assert only one Output(LSTMX) of AttentionLSTM."); "Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(fair_output("LSTMOUT"), PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Assert only one Output(LSTMOUT) of AttentionLSTM."); "Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
...@@ -66,7 +65,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -66,7 +65,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (fair_input("H0")) { if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) " "The dimension of Input(H0) and Input(C0) "
...@@ -80,7 +79,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -80,7 +79,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (fair_input("AttentionBias")) { if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias"); auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2."); "Input(AttentionBias)'s rank must be 2.");
...@@ -90,7 +89,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -90,7 +89,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"AttentionBias shapes must be 1 * 1."); "AttentionBias shapes must be 1 * 1.");
} }
if (fair_input("AttentionScalar")) { if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar"); auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2."); "Input(AttentionScalar)'s rank must be 2.");
...@@ -98,10 +97,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -98,10 +97,10 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
} }
if (fair_input("AttentionScalarBias")) { if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias"); auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE( PADDLE_ENFORCE(
fair_input("AttentionScalar"), ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias."); "AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2."); "Input(AttentionScalarBias)'s rank must be 2.");
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h" #include "paddle/fluid/operators/fusion_gru_op.h"
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/fusion_infershape_define.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
...@@ -26,14 +25,13 @@ namespace paddle { ...@@ -26,14 +25,13 @@ namespace paddle {
namespace operators { namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU."); PADDLE_ENFORCE(ctx->HasInput("WeightX"),
PADDLE_ENFORCE(fair_input("WeightX"),
"Assert only one Input(WeightX) of GRU."); "Assert only one Input(WeightX) of GRU.");
PADDLE_ENFORCE(fair_input("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Assert only one Input(WeightH) of GRU."); "Assert only one Input(WeightH) of GRU.");
PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
PADDLE_ENFORCE(fair_output("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of GRU."); "Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
...@@ -60,12 +58,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -60,12 +58,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"should be 3 * %d.", "should be 3 * %d.",
frame_size); frame_size);
if (fair_input("H0")) { if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0"); auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size."); "The width of H0 must be equal to frame_size.");
} }
if (fair_input("Bias")) { if (ctx->HasInput("Bias")) {
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
...@@ -81,11 +79,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -81,11 +79,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(fair_output("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Assert only one Output(ReorderedH0) of GRU."); "Assert only one Output(ReorderedH0) of GRU.");
PADDLE_ENFORCE(fair_output("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Assert only one Output(BatchedInput) of GRU."); "Assert only one Output(BatchedInput) of GRU.");
PADDLE_ENFORCE(fair_output("BatchedOut"), PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Assert only one Output(BatchedOut) of GRU."); "Assert only one Output(BatchedOut) of GRU.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_
#define PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_
#include <string>
#include "paddle/fluid/framework/shape_runtime_infer.h"
namespace paddle {
namespace operators {
// Prologue shared by the InferShape() of fused operators (attention_lstm,
// fusion_gru, fusion_lstm, ...). It must be expanded inside an
// InferShape(framework::InferShapeContext* ctx) body, where it:
//   - downcasts `ctx` to framework::RuntimeInferShapeContext (fatal if the
//     context is not a runtime one, since the lambdas below need a Scope);
//   - binds `ins` / `outs` (the op's name->variable-list maps) and `scope`;
//   - defines two lambdas:
//       fair_input(name)  -> true iff input slot `name` exists, holds exactly
//                            one non-empty variable name, and that variable
//                            is found in the scope;
//       fair_output(name) -> same check for an output slot.
// NOTE: implemented as a macro (not a function) because it injects these
// local names into the caller's scope; no `//` comments may appear inside
// the definition — they would swallow the line-continuation backslashes.
#define FUSION_INFERSHAPE_INIT \
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx); \
if (runtime_ctx == nullptr) { \
LOG(FATAL) << "Should have runtime infer context"; \
} \
const auto& ins = runtime_ctx->OpBase().Inputs(); \
const auto& outs = runtime_ctx->OpBase().Outputs(); \
const auto& scope = runtime_ctx->InferScope(); \
const auto ins_end = ins.end(); \
const auto outs_end = outs.end(); \
auto fair_input = [&](const std::string& name) -> bool { \
auto it = ins.find(name); \
if (it == ins_end) { \
return false; \
} \
const auto& in = it->second; \
if (in.size() != 1 || in[0] == framework::kEmptyVarName) { \
return false; \
} \
return scope.FindVar(in[0]) != nullptr; \
}; \
auto fair_output = [&](const std::string& name) -> bool { \
auto it = outs.find(name); \
if (it == outs_end) { \
return false; \
} \
const auto& out = it->second; \
if (out.size() != 1 || out[0] == framework::kEmptyVarName) { \
return false; \
} \
return scope.FindVar(out[0]) != nullptr; \
}
} // namespace operators
} // namespace paddle
#endif  // PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h" #include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/fusion_infershape_define.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
...@@ -25,23 +24,23 @@ namespace paddle { ...@@ -25,23 +24,23 @@ namespace paddle {
namespace operators { namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
FUSION_INFERSHAPE_INIT; PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM.");
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("WeightX"),
PADDLE_ENFORCE(fair_input("WeightX"),
"Assert only one Input(WeightX) of LSTM."); "Assert only one Input(WeightX) of LSTM.");
PADDLE_ENFORCE(fair_input("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Assert only one Input(WeightH) of LSTM."); "Assert only one Input(WeightH) of LSTM.");
PADDLE_ENFORCE(fair_input("Bias"), "Assert only one Input(Bias) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
PADDLE_ENFORCE(fair_output("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of LSTM."); "Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(fair_output("Cell"), "Assert only one Output(Cell) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
if (fair_input("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(fair_input("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not " "Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time."); "be null at the same time.");
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
...@@ -93,15 +92,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,15 +92,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(fair_output("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Assert only one Output(BatchedInput) of LSTM."); "Assert only one Output(BatchedInput) of LSTM.");
PADDLE_ENFORCE(fair_output("BatchedHidden"), PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Assert only one Output(BatchedHidden) of LSTM."); "Assert only one Output(BatchedHidden) of LSTM.");
PADDLE_ENFORCE(fair_output("BatchedCell"), PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Assert only one Output(BatchedCell) of LSTM."); "Assert only one Output(BatchedCell) of LSTM.");
PADDLE_ENFORCE(fair_output("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Assert only one Output(ReorderedH0) of LSTM"); "Assert only one Output(ReorderedH0) of LSTM");
PADDLE_ENFORCE(fair_output("ReorderedC0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Assert only one Output(ReorderedC0) of LSTM."); "Assert only one Output(ReorderedC0) of LSTM.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册