Commit a4c528a3 authored by baojun, committed by tensor-tang

[NGraph] some ngraph updates to enable bert (#17739)

* delay infershape test=develop

* fall back subblock to paddle test=develop

* fix edge cases test=develop

* remove output duplicates test=develop

* handle reshape2_grad infershape test=develop
Parent 3d3f5506
@@ -334,7 +334,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
 #ifdef PADDLE_WITH_NGRAPH
-  if (FLAGS_use_ngraph) {
+  if (FLAGS_use_ngraph && ctx->block_id_ == 0) {
     paddle::operators::NgraphEngine::FuseNgraphOps(
         ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
   }
......
@@ -36,14 +36,14 @@ bool NgraphBridge::isRegister(const std::string& str) {
 
 bool NgraphBridge::isSupported(
     const std::unique_ptr<framework::OperatorBase>& op) {
-  static std::unordered_set<std::string> skip_op_list{"reshape", "reshape2",
-                                                      "lookup_table"};
+  static std::unordered_set<std::string> skip_op_list{
+      "reshape", "reshape2", "lookup_table", "lookup_table_grad"};
   bool result = true;
   auto& op_type = op->Type();
   auto op_attrs = paddle::framework::AttrReader(op->Attrs());
   if (!isRegister(op_type)) {
     if (skip_op_list.count(op_type)) {
-      if (op_type == "lookup_table") {
+      if (op_type == "lookup_table" || op_type == "lookup_table_grad") {
         if (op_attrs.Get<bool>("is_sparse") ||
             (op_attrs.Get<int64_t>("padding_idx") != kNoPadding)) {
           result = false;
......
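The skip list now covers lookup_table_grad as well, and the inner check treats it exactly like lookup_table: a sparse embedding or a real padding_idx marks the op as unsupported so it falls back to Paddle. A minimal restatement of that predicate (a sketch; it assumes Paddle's usual kNoPadding value of -1, which is defined elsewhere in the op headers):

#include <cstdint>

constexpr int64_t kNoPadding = -1;  // assumed value of Paddle's kNoPadding

// Mirrors the inner condition above: these attribute combinations force
// lookup_table / lookup_table_grad off the nGraph path.
bool MustFallBackToPaddle(bool is_sparse, int64_t padding_idx) {
  return is_sparse || padding_idx != kNoPadding;
}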
@@ -38,6 +38,10 @@ namespace operators {
 
 static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
   ngraph::Shape sp;
+  if (dims.size() == 1 && dims[0] == 0) {
+    sp.emplace_back(0);
+    return sp;
+  }
   for (int i = 0; i < dims.size(); ++i) {
     int k = dims[i];
     k = k == 0 ? 1 : k;
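The new early return keeps a genuinely empty 1-D tensor (DDim [0]) as shape {0} instead of promoting it to {1}. A small self-contained sketch of the conversion, with framework::DDim replaced by a plain vector for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for Ddim2Shape after this patch (illustration only).
std::vector<size_t> Ddim2ShapeSketch(const std::vector<int64_t>& dims) {
  std::vector<size_t> sp;
  if (dims.size() == 1 && dims[0] == 0) {
    sp.push_back(0);  // an empty 1-D tensor keeps its zero extent
    return sp;
  }
  for (int64_t k : dims) sp.push_back(k == 0 ? 1 : k);  // 0 inside a larger rank still maps to 1
  return sp;
}

int main() {
  for (size_t d : Ddim2ShapeSketch({0})) std::cout << d << ' ';       // prints: 0
  std::cout << '\n';
  for (size_t d : Ddim2ShapeSketch({0, 768})) std::cout << d << ' ';  // prints: 1 768
  std::cout << '\n';
  return 0;
}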
@@ -417,6 +421,15 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
       }
     }
   }
+  // remove output duplicates
+  std::unordered_set<std::string> var_out_set;
+  for (int i = static_cast<int>(var_out_.size()) - 1; i >= 0; --i) {
+    std::string var_name = var_out_.at(i);
+    if (var_out_set.count(var_name)) {
+      var_out_.erase(var_out_.begin() + i);
+    }
+    var_out_set.insert(var_name);
+  }
 }
 
 void NgraphEngine::GetNgInputShape() {
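The dedup pass walks var_out_ from the back, so when an output variable name appears more than once only its last occurrence survives. A standalone sketch of the same scan (the variable names are made up):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::string> var_out = {"fc_0.tmp_1", "embedding_0.tmp_0", "fc_0.tmp_1"};
  std::unordered_set<std::string> var_out_set;
  for (int i = static_cast<int>(var_out.size()) - 1; i >= 0; --i) {
    std::string var_name = var_out.at(i);
    if (var_out_set.count(var_name)) {
      var_out.erase(var_out.begin() + i);  // drop the earlier duplicate
    }
    var_out_set.insert(var_name);
  }
  for (const auto& v : var_out) std::cout << v << '\n';  // embedding_0.tmp_0, fc_0.tmp_1
}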
@@ -458,16 +471,8 @@ void NgraphEngine::BuildNgNodes() {
   }
 }
 
-void NgraphEngine::RunInferShape() {
-  for (auto& op : fused_ops_) {
-    framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
-    op->RuntimeInferShape(scope_, place_, ctx);
-  }
-}
-
 void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
   Prepare(ctx);
-  RunInferShape();
   GetNgInputShape();
   BuildNgNodes();
   ngraph_function_ = nullptr;
@@ -626,6 +631,21 @@ void NgraphEngine::Run(const framework::Scope& scope,
     }
   }
+
+  for (auto& op : fused_ops_) {
+    framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
+    if (op->Type() == "reshape2_grad") {
+      auto xshape_name = op->Inputs().at("XShape").at(0);
+      auto* xshape_var = scope_.FindVar(xshape_name);
+      auto* xshape_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*xshape_var);
+      auto& xshape_ddim = xshape_tensor->dims();
+      auto xgrad_name = op->Outputs().at(framework::GradVarName("X")).at(0);
+      auto* xgrad_var = scope_.FindVar(xgrad_name);
+      xgrad_var->GetMutable<framework::LoDTensor>()->Resize(xshape_ddim);
+    } else {
+      op->RuntimeInferShape(scope_, place_, ctx);
+    }
+  }
 
   std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
   for (size_t i = 0; i < p_var_out->size(); ++i) {
     auto vo = p_var_out->at(i);
......
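Shape inference for the fused ops is now deferred to Run() (the "delay infershape" item in the commit message), and reshape2_grad gets a shortcut: X@GRAD is resized straight from the dims stored in the XShape input instead of calling the op's RuntimeInferShape. A toy model of why that is sufficient (names and types here are illustrative, not Paddle's):

#include <cassert>
#include <vector>

// Reshape only changes how the flat buffer is viewed, so the gradient w.r.t. X
// is the upstream gradient reinterpreted with X's original dims, which is what
// the XShape record preserves across the forward pass.
struct ToyTensor {
  std::vector<int> dims;
  std::vector<float> data;
};

int main() {
  ToyTensor x{{2, 12}, std::vector<float>(24, 1.f)};
  std::vector<int> xshape = x.dims;                      // recorded at forward time
  ToyTensor out_grad{{2, 3, 4}, std::vector<float>(24, 0.5f)};
  ToyTensor x_grad;
  x_grad.dims = xshape;                                  // the Resize(...) above
  x_grad.data = out_grad.data;                           // same flat buffer
  assert(x_grad.dims == x.dims);
  return 0;
}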
@@ -109,8 +109,6 @@ class NgraphEngine {
   void GetNgInputShape();
   // Call ngraph bridge to map ops
   void BuildNgNodes();
-  // run paddle RuntimeInferShape to get the tensor shape
-  void RunInferShape();
   // build ngraph function call
   void BuildNgFunction(const framework::ExecutionContext& ctx);
   // Check cache for ngraph function or otherwise build the function
......
@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "ngraph/ngraph.hpp"
@@ -42,11 +44,11 @@ ngraph::NodeVector ElementwiseBinaryNodePrepare(
   if (lhs_shape == rhs_shape) {
     return ngraph::NodeVector{lhs, rhs};
   }
+  axis = (rhs_shape.size() == 0) ? lhs_shape.size() - 1 : axis;
   axis = (axis == -1 ? lhs_shape.size() - rhs_shape.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < (int)(lhs_shape.size()),
                  "Axis should be in range [0, lhs_shape)");
   paddle::platform::TrimTrailingSingularDims(&rhs_shape);
-  axis = (rhs_shape.size() == 0) ? lhs_shape.size() : axis;
   int pre, n, post;
   paddle::platform::GetMidDims(lhs_shape, rhs_shape, axis, &pre, &n, &post);
......
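The added line handles a scalar (rank-0) right-hand operand before the axis checks run. Previously, with axis left at -1, the fallback axis = lhs_rank - rhs_rank degenerated to lhs_rank for an empty rhs_shape and tripped the range check; the new line pins broadcasting to the last lhs axis instead. A small sketch with illustrative numbers:

#include <cassert>
#include <vector>

int main() {
  std::vector<size_t> lhs_shape = {2, 3, 4};
  std::vector<size_t> rhs_shape = {};  // scalar operand
  int axis = -1;                       // default attribute value

  // New order: pin a scalar rhs to the last lhs axis first...
  axis = rhs_shape.empty() ? static_cast<int>(lhs_shape.size()) - 1 : axis;
  // ...then apply the usual axis == -1 rule for the remaining cases.
  axis = (axis == -1) ? static_cast<int>(lhs_shape.size() - rhs_shape.size()) : axis;

  assert(axis == 2);  // previously axis became 3 and failed axis < lhs_shape.size()
  return 0;
}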
@@ -35,6 +35,7 @@ static void BuildMulNode(
   int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
   auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
   auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
+  int y_rank = y->get_shape().size();
 
   auto x_reshape = x;
   auto y_reshape = y;
@@ -52,10 +53,14 @@ static void BuildMulNode(
 
   std::shared_ptr<ngraph::Node> out =
       std::make_shared<ngraph::op::Dot>(x_reshape, y_reshape);
-  auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
-  if (dummy_out && dummy_out->get_shape() != out->get_shape()) {
-    out = paddle::platform::NgReshaper(out, dummy_out->get_shape());
-  }
+  ngraph::Shape out_shape;
+  for (int i = 0; i < x_num_col_dims; ++i) {
+    out_shape.push_back(x->get_shape()[i]);
+  }
+  for (int i = y_num_col_dims; i < y_rank; ++i) {
+    out_shape.push_back(y->get_shape()[i]);
+  }
+  out = paddle::platform::NgReshaper(out, out_shape);
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
......
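Instead of copying the shape of a dummy output node (which may not be inferred yet once infershape is delayed), the mul output shape is now built from the inputs and the *_num_col_dims attributes: the leading x_num_col_dims dims of X followed by the trailing dims of Y. A worked example with illustrative shapes:

#include <cassert>
#include <vector>

int main() {
  // X: [2, 3, 4] with x_num_col_dims = 2  -> flattened to [6, 4]
  // Y: [4, 5]    with y_num_col_dims = 1  -> stays [4, 5]
  // Dot gives [6, 5]; the final reshape restores [2, 3, 5].
  std::vector<size_t> x_shape = {2, 3, 4}, y_shape = {4, 5};
  int x_num_col_dims = 2, y_num_col_dims = 1;
  int y_rank = static_cast<int>(y_shape.size());

  std::vector<size_t> out_shape;
  for (int i = 0; i < x_num_col_dims; ++i) out_shape.push_back(x_shape[i]);
  for (int i = y_num_col_dims; i < y_rank; ++i) out_shape.push_back(y_shape[i]);

  assert((out_shape == std::vector<size_t>{2, 3, 5}));
  return 0;
}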
@@ -60,17 +60,20 @@ void BuildPool2dNode(
   ngraph::Strides ng_strides{static_cast<size_t>(strides.at(0)),
                              static_cast<size_t>(strides.at(1))};
 
-  auto ComputeCeiledOutput = [](size_t in, size_t k, size_t p, size_t s) {
+  auto ComputeFlooredOutput = [](size_t in, size_t k, size_t p, size_t s) {
     return (in - k + 2 * p) / s + 1;
   };
+  auto ComputeCeiledOutput = [](size_t in, size_t k, size_t p, size_t s) {
+    return ceil(static_cast<float>(in - k + 2 * p) / s) + 1;
+  };
 
   if (op_attrs.Get<bool>("ceil_mode")) {
-    auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
-    auto dummpy_shape = dummy_out->get_shape();
     for (size_t i = 0; i < ng_padding_above.size(); ++i) {
-      auto desired_size = ComputeCeiledOutput(x_shape[i + 2], ksize[i],
-                                              paddings[i], strides[i]);
-      if (desired_size != dummpy_shape[i + 2]) {
+      auto ceiled_size = ComputeCeiledOutput(x_shape[i + 2], ksize[i],
+                                             paddings[i], strides[i]);
+      auto floored_size = ComputeFlooredOutput(x_shape[i + 2], ksize[i],
+                                               paddings[i], strides[i]);
+      if (ceiled_size != floored_size) {
         ng_padding_above[i] += strides[i];
       }
     }
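Ceil-mode handling no longer consults the dummy output node's shape; it compares the floor-mode and ceil-mode output extents directly and, where they differ, grows the upper padding by one stride so nGraph's floor-based pooling yields the ceil-mode size. A quick numeric check with illustrative values:

#include <cmath>
#include <cstdio>

int main() {
  // input extent 6, kernel 3, padding 0, stride 2
  size_t in = 6, k = 3, p = 0, s = 2;
  size_t floored = (in - k + 2 * p) / s + 1;  // (6 - 3) / 2 + 1 = 2
  size_t ceiled =
      static_cast<size_t>(std::ceil(static_cast<float>(in - k + 2 * p) / s)) + 1;  // 3
  std::printf("floored=%zu ceiled=%zu extra_pad_above=%zu\n", floored, ceiled,
              ceiled != floored ? s : size_t(0));
  return 0;
}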
@@ -96,6 +99,10 @@ void BuildPool2dNode(
     pool2d =
         std::make_shared<ngraph::op::AvgPool>(x, ng_ksize_shape, ng_strides);
   } else {
+    if ((ng_padding_below[0] == 0) && (ng_padding_below[1] == 0) &&
+        (ng_padding_above[0] == 0) && (ng_padding_above[1] == 0)) {
+      padding_exclusive = false;
+    }
     pool2d = std::make_shared<ngraph::op::AvgPool>(
         x, ng_ksize_shape, ng_strides, ng_padding_below, ng_padding_above,
         !padding_exclusive);
@@ -163,6 +170,10 @@ void BuildPool2dGradNode(
         x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below,
         ng_padding_above, !padding_exclusive);
   } else {
+    if ((ng_padding_below[0] == 0) && (ng_padding_below[1] == 0) &&
+        (ng_padding_above[0] == 0) && (ng_padding_above[1] == 0)) {
+      padding_exclusive = false;
+    }
     pool2d_grad = std::make_shared<ngraph::op::AvgPoolBackprop>(
         x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below,
         ng_padding_above, !padding_exclusive);
......
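In both the forward and backward average-pool branches, padding_exclusive is forced to false when every padding value is zero. With no padded cells the inclusive and exclusive averages are identical, so the change is numerically a no-op; presumably it just sidesteps a constraint on the exclude-pad flag in nGraph's AvgPool validation (that motivation is an assumption, not stated in the diff). A trivial check of the equivalence:

#include <cassert>

int main() {
  // One 2x2 pooling window with zero padding: the divisor is 4 either way.
  float sum = 1.f + 2.f + 3.f + 4.f;
  int real_elems = 4, padded_elems = 0;
  float avg_including_pad = sum / (real_elems + padded_elems);
  float avg_excluding_pad = sum / real_elems;
  assert(avg_including_pad == avg_excluding_pad);
  return 0;
}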