diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 0534067254613008d7ff4e531ecb7368a0ad6925..40ab2ad7009e5f89a51a1854c109e47f9ef5df7b 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -334,7 +334,7 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
 #ifdef PADDLE_WITH_NGRAPH
-  if (FLAGS_use_ngraph) {
+  if (FLAGS_use_ngraph && ctx->block_id_ == 0) {
     paddle::operators::NgraphEngine::FuseNgraphOps(
         ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
   }
diff --git a/paddle/fluid/operators/ngraph/ngraph_bridge.cc b/paddle/fluid/operators/ngraph/ngraph_bridge.cc
index 4ff50935d6c78a01db222dcc8bcca3b22985d943..db8a7ca94a557d1d93b7dc73b2eee4a36d3783e3 100644
--- a/paddle/fluid/operators/ngraph/ngraph_bridge.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_bridge.cc
@@ -36,14 +36,14 @@ bool NgraphBridge::isRegister(const std::string& str) {
 
 bool NgraphBridge::isSupported(
     const std::unique_ptr<framework::OperatorBase>& op) {
-  static std::unordered_set<std::string> skip_op_list{"reshape", "reshape2",
-                                                      "lookup_table"};
+  static std::unordered_set<std::string> skip_op_list{
+      "reshape", "reshape2", "lookup_table", "lookup_table_grad"};
   bool result = true;
   auto& op_type = op->Type();
   auto op_attrs = paddle::framework::AttrReader(op->Attrs());
   if (!isRegister(op_type)) {
     if (skip_op_list.count(op_type)) {
-      if (op_type == "lookup_table") {
+      if (op_type == "lookup_table" || op_type == "lookup_table_grad") {
         if (op_attrs.Get<bool>("is_sparse") ||
             (op_attrs.Get<int64_t>("padding_idx") != kNoPadding)) {
           result = false;
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc
index d93e3c504f82440c8773adda2dff73c11a9b4113..19d30a6f8387b528d3410091d54d7f9bf49e39b3 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -38,6 +38,10 @@ namespace operators {
 
 static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
   ngraph::Shape sp;
+  if (dims.size() == 1 && dims[0] == 0) {
+    sp.emplace_back(0);
+    return sp;
+  }
   for (int i = 0; i < dims.size(); ++i) {
     int k = dims[i];
     k = k == 0 ? 1 : k;
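Not part of the patch: a minimal, standalone sketch of the Ddim2Shape behavior after the hunk above, with std::vector<int64_t> standing in for framework::DDim and std::vector<size_t> for ngraph::Shape; the function name is hypothetical.

```cpp
#include <cstdint>
#include <vector>

// Sketch of the zero-dim handling added to Ddim2Shape (names are illustrative).
std::vector<size_t> Ddim2ShapeSketch(const std::vector<int64_t>& dims) {
  std::vector<size_t> sp;
  // A 1-D dim of size 0 (an empty tensor) now maps to shape {0}
  // instead of being promoted to {1}.
  if (dims.size() == 1 && dims[0] == 0) {
    sp.emplace_back(0);
    return sp;
  }
  for (size_t i = 0; i < dims.size(); ++i) {
    const int64_t k = dims[i];
    sp.emplace_back(k == 0 ? 1 : k);  // any other zero dim still falls back to 1
  }
  return sp;
}
```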
@@ -417,6 +421,15 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
       }
     }
   }
+  // remove output duplicates
+  std::unordered_set<std::string> var_out_set;
+  for (int i = static_cast<int>(var_out_.size()) - 1; i >= 0; --i) {
+    std::string var_name = var_out_.at(i);
+    if (var_out_set.count(var_name)) {
+      var_out_.erase(var_out_.begin() + i);
+    }
+    var_out_set.insert(var_name);
+  }
 }
 
 void NgraphEngine::GetNgInputShape() {
@@ -458,16 +471,8 @@ void NgraphEngine::BuildNgNodes() {
   }
 }
 
-void NgraphEngine::RunInferShape() {
-  for (auto& op : fused_ops_) {
-    framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
-    op->RuntimeInferShape(scope_, place_, ctx);
-  }
-}
-
 void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
   Prepare(ctx);
-  RunInferShape();
   GetNgInputShape();
   BuildNgNodes();
   ngraph_function_ = nullptr;
@@ -626,6 +631,21 @@ void NgraphEngine::Run(const framework::Scope& scope,
     }
   }
 
+  for (auto& op : fused_ops_) {
+    framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
+    if (op->Type() == "reshape2_grad") {
+      auto xshape_name = op->Inputs().at("XShape").at(0);
+      auto* xshape_var = scope_.FindVar(xshape_name);
+      auto* xshape_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*xshape_var);
+      auto& xshape_ddim = xshape_tensor->dims();
+      auto xgrad_name = op->Outputs().at(framework::GradVarName("X")).at(0);
+      auto* xgrad_var = scope_.FindVar(xgrad_name);
+      xgrad_var->GetMutable<framework::LoDTensor>()->Resize(xshape_ddim);
+    } else {
+      op->RuntimeInferShape(scope_, place_, ctx);
+    }
+  }
+
   std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
   for (size_t i = 0; i < p_var_out->size(); ++i) {
     auto vo = p_var_out->at(i);
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h
index 0e36204a4470622f29f619c4369070c079b651dd..885b738b95b473ef3455f93b671a5f4bb0b6730d 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.h
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.h
@@ -109,8 +109,6 @@ class NgraphEngine {
   void GetNgInputShape();
   // Call ngraph bridge to map ops
   void BuildNgNodes();
-  // run paddle RuntimeInferShape to get the tensor shape
-  void RunInferShape();
   // build ngraph function call
   void BuildNgFunction(const framework::ExecutionContext& ctx);
   // Check cache for ngraph function or otherwise build the function
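Not part of the patch: a minimal sketch of the duplicate-removal pass added to NgraphEngine::BuildNgIO above, pulled out as a free function with a hypothetical name so its keep-the-last-occurrence behavior is easy to see.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Walks the output list back to front and erases any name that also appears
// at a later index, so the last occurrence of each output variable survives.
void RemoveOutputDuplicates(std::vector<std::string>* var_out) {
  std::unordered_set<std::string> seen;
  for (int i = static_cast<int>(var_out->size()) - 1; i >= 0; --i) {
    std::string name = var_out->at(i);  // copy: the element may be erased below
    if (seen.count(name)) {
      var_out->erase(var_out->begin() + i);
    }
    seen.insert(name);
  }
}
```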
diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h b/paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h
index 8732932dedd4401853325b629877880cc90f6cb6..e4e17f5bb219bdf82db99fce2ea4fe5dbcb6e0c9 100644
--- a/paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h
+++ b/paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h
@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "ngraph/ngraph.hpp"
@@ -42,11 +44,11 @@ ngraph::NodeVector ElementwiseBinaryNodePrepare(
   if (lhs_shape == rhs_shape) {
     return ngraph::NodeVector{lhs, rhs};
   }
+  axis = (rhs_shape.size() == 0) ? lhs_shape.size() - 1 : axis;
   axis = (axis == -1 ? lhs_shape.size() - rhs_shape.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < (int)(lhs_shape.size()),
                  "Axis should be in range [0, lhs_shape)");
   paddle::platform::TrimTrailingSingularDims(&rhs_shape);
-  axis = (rhs_shape.size() == 0) ? lhs_shape.size() : axis;
 
   int pre, n, post;
   paddle::platform::GetMidDims(lhs_shape, rhs_shape, axis, &pre, &n, &post);
diff --git a/paddle/fluid/operators/ngraph/ops/mul_op.h b/paddle/fluid/operators/ngraph/ops/mul_op.h
index d13665864b8950436298b7cf685c803593007803..cb46478ee8ad4f4c51a6ff9d6f5de4e66f6a505f 100644
--- a/paddle/fluid/operators/ngraph/ops/mul_op.h
+++ b/paddle/fluid/operators/ngraph/ops/mul_op.h
@@ -35,6 +35,7 @@ static void BuildMulNode(
   int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
   auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
   auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
+  int y_rank = y->get_shape().size();
 
   auto x_reshape = x;
   auto y_reshape = y;
@@ -52,10 +53,14 @@ static void BuildMulNode(
   std::shared_ptr<ngraph::Node> out =
       std::make_shared<ngraph::op::Dot>(x_reshape, y_reshape);
 
-  auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
-  if (dummy_out && dummy_out->get_shape() != out->get_shape()) {
-    out = paddle::platform::NgReshaper(out, dummy_out->get_shape());
+  ngraph::Shape out_shape;
+  for (int i = 0; i < x_num_col_dims; ++i) {
+    out_shape.push_back(x->get_shape()[i]);
   }
+  for (int i = y_num_col_dims; i < y_rank; ++i) {
+    out_shape.push_back(y->get_shape()[i]);
+  }
+  out = paddle::platform::NgReshaper(out, out_shape);
 
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
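Not part of the patch: a small sketch of how the mul output shape in the hunk above is assembled from the operand shapes instead of being read back from a dummy output node; the function name and the plain std::vector<size_t> stand-in for ngraph::Shape are illustrative.

```cpp
#include <cstddef>
#include <vector>

// Output shape = leading dims of X up to x_num_col_dims, followed by the
// trailing dims of Y from y_num_col_dims onward.
std::vector<size_t> MulOutShapeSketch(const std::vector<size_t>& x_shape,
                                      const std::vector<size_t>& y_shape,
                                      int x_num_col_dims, int y_num_col_dims) {
  std::vector<size_t> out_shape;
  for (int i = 0; i < x_num_col_dims; ++i) {
    out_shape.push_back(x_shape[i]);
  }
  for (int i = y_num_col_dims; i < static_cast<int>(y_shape.size()); ++i) {
    out_shape.push_back(y_shape[i]);
  }
  return out_shape;
}

// e.g. X = {2, 3, 4} with x_num_col_dims = 1 and Y = {12, 5} with
// y_num_col_dims = 1: the dot product is 2x12 * 12x5, and the result is
// reshaped to out_shape = {2, 5}.
```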
diff --git a/paddle/fluid/operators/ngraph/ops/pool2d_op.h b/paddle/fluid/operators/ngraph/ops/pool2d_op.h
index c7b9c9316171a448d16ed68339f5754d25f3cabd..e5542d4715740ad9f2ab7315dcfa20434a08f3fa 100644
--- a/paddle/fluid/operators/ngraph/ops/pool2d_op.h
+++ b/paddle/fluid/operators/ngraph/ops/pool2d_op.h
@@ -60,17 +60,20 @@ void BuildPool2dNode(
   ngraph::Strides ng_strides{static_cast<size_t>(strides.at(0)),
                              static_cast<size_t>(strides.at(1))};
 
-  auto ComputeCeiledOutput = [](size_t in, size_t k, size_t p, size_t s) {
+  auto ComputeFlooredOutput = [](size_t in, size_t k, size_t p, size_t s) {
     return (in - k + 2 * p) / s + 1;
   };
+  auto ComputeCeiledOutput = [](size_t in, size_t k, size_t p, size_t s) {
+    return ceil(static_cast<float>(in - k + 2 * p) / s) + 1;
+  };
 
   if (op_attrs.Get<bool>("ceil_mode")) {
-    auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
-    auto dummpy_shape = dummy_out->get_shape();
     for (size_t i = 0; i < ng_padding_above.size(); ++i) {
-      auto desired_size = ComputeCeiledOutput(x_shape[i + 2], ksize[i],
-                                              paddings[i], strides[i]);
-      if (desired_size != dummpy_shape[i + 2]) {
+      auto ceiled_size = ComputeCeiledOutput(x_shape[i + 2], ksize[i],
+                                             paddings[i], strides[i]);
+      auto floored_size = ComputeFlooredOutput(x_shape[i + 2], ksize[i],
+                                               paddings[i], strides[i]);
+      if (ceiled_size != floored_size) {
         ng_padding_above[i] += strides[i];
       }
     }
@@ -96,6 +99,10 @@ void BuildPool2dNode(
       pool2d = std::make_shared<ngraph::op::AvgPool>(x, ng_ksize_shape,
                                                      ng_strides);
     } else {
+      if ((ng_padding_below[0] == 0) && (ng_padding_below[1] == 0) &&
+          (ng_padding_above[0] == 0) && (ng_padding_above[1] == 0)) {
+        padding_exclusive = false;
+      }
       pool2d = std::make_shared<ngraph::op::AvgPool>(
           x, ng_ksize_shape, ng_strides, ng_padding_below, ng_padding_above,
           !padding_exclusive);
@@ -163,6 +170,10 @@ void BuildPool2dGradNode(
           x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below,
           ng_padding_above, !padding_exclusive);
     } else {
+      if ((ng_padding_below[0] == 0) && (ng_padding_below[1] == 0) &&
+          (ng_padding_above[0] == 0) && (ng_padding_above[1] == 0)) {
+        padding_exclusive = false;
+      }
       pool2d_grad = std::make_shared<ngraph::op::AvgPoolBackprop>(
           x->get_shape(), dout, ng_ksize_shape, ng_strides, ng_padding_below,
           ng_padding_above, !padding_exclusive);
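Not part of the patch: the two output-size formulas that the ceil_mode check above now compares, written out for a single spatial dimension; when they differ, the patch grows ng_padding_above by one stride instead of consulting a precomputed output shape. Names here are illustrative.

```cpp
#include <cmath>
#include <cstddef>

// Floor-mode pooling output size for one dimension (integer division floors).
size_t FlooredPoolOut(size_t in, size_t k, size_t p, size_t s) {
  return (in - k + 2 * p) / s + 1;
}

// Ceil-mode pooling output size for one dimension.
size_t CeiledPoolOut(size_t in, size_t k, size_t p, size_t s) {
  return static_cast<size_t>(
             std::ceil(static_cast<float>(in - k + 2 * p) / s)) +
         1;
}

// e.g. in = 8, k = 3, p = 0, s = 2: floored = 3, ceiled = 4, so ceil_mode
// needs one extra stride of top/right padding to produce the larger output.
```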