Commit bfdab00e authored by luotao1

Merge branch 'develop' into core_opt_choose_kernel

@@ -70,6 +70,7 @@ pass_library(conv_affine_channel_fuse_pass inference)
 pass_library(transpose_flatten_concat_fuse_pass inference)
 pass_library(identity_scale_op_clean_pass base)
 pass_library(sync_batch_norm_pass base)
+pass_library(runtime_context_cache_pass base)
 # There may be many transpose-flatten structures in a model, and the output of
 # these structures will be used as inputs to the concat Op. This pattern will
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
#include <memory>
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  VLOG(3) << "Applies Runtime Context Cache strategy.";
  for (const Node* n : graph->Nodes()) {
    if (n->IsOp()) {
      n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
    }
  }
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(runtime_context_cache_pass,
              paddle::framework::ir::RuntimeContextCachePass);
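
For context, a minimal sketch of how a pass registered this way is typically looked up and run over an ir::Graph. The PassRegistry lookup and the Apply call are assumptions based on the surrounding Pass API; this helper is not part of the commit.

// Sketch only: apply the pass registered above to a graph.
// Assumes the usual PassRegistry/Pass API; not part of this commit.
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

std::unique_ptr<paddle::framework::ir::Graph> EnableRuntimeContextCache(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "runtime_context_cache_pass");
  // ApplyImpl above sets kEnableCacheRuntimeContext on every op node.
  return pass->Apply(std::move(graph));
}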
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class RuntimeContextCachePass : public Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -876,12 +876,27 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
-  RuntimeContext ctx(Inputs(), Outputs(), scope);
+  if (!HasAttr(kEnableCacheRuntimeContext)) {
+    RuntimeContext ctx(Inputs(), Outputs(), scope);
+    RunImpl(scope, place, &ctx);
+  } else {
+    const Scope* cur_scope = &scope;
+    if (!runtime_ctx_ || pre_scope_ != cur_scope) {
+      runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
+      pre_scope_ = cur_scope;
+    }
+    RunImpl(scope, place, runtime_ctx_.get());
+  }
+}
+
+void OperatorWithKernel::RunImpl(const Scope& scope,
+                                 const platform::Place& place,
+                                 RuntimeContext* runtime_ctx) const {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
   if (!kernel_type_) {
-    ChooseKernel(ctx, scope, place);
+    ChooseKernel(*runtime_ctx, scope, place);
   }
   std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
@@ -889,7 +904,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
   auto* transfer_scope =
-      PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx);
+      PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
   // exec scope is the scope that kernel actually executed on.
   const Scope& exec_scope =
@@ -900,13 +915,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
   if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
-    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
+    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
     this->InferShape(&infer_shape_ctx);
   }
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
-  (*kernel_func_)(
-      ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
+  (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
+                                   kernel_configs));
   if (!transfered_inplace_vars.empty()) {
     // there is inplace variable has been transfered.
......
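
The new branch in RunImpl above amounts to a scope-keyed cache: the RuntimeContext is rebuilt only on the first run or when the op is executed against a different Scope. A self-contained sketch of that pattern, with purely illustrative names (not Paddle API):

// Illustrative sketch of the scope-keyed caching used in RunImpl above.
// All type and member names here are hypothetical; only the control flow
// mirrors the commit.
#include <memory>

struct Scope {};

struct RuntimeContext {
  explicit RuntimeContext(const Scope& s) { (void)s; }
};

class CachedOp {
 public:
  void Run(const Scope& scope, bool enable_cache) {
    if (!enable_cache) {
      // No caching: rebuild the context on every call.
      RuntimeContext ctx(scope);
      RunWith(&ctx);
    } else {
      // Rebuild only on the first run or when the scope changes.
      if (!runtime_ctx_ || pre_scope_ != &scope) {
        runtime_ctx_.reset(new RuntimeContext(scope));
        pre_scope_ = &scope;
      }
      RunWith(runtime_ctx_.get());
    }
  }

 private:
  void RunWith(RuntimeContext* ctx) { (void)ctx; }

  std::unique_ptr<RuntimeContext> runtime_ctx_;
  const Scope* pre_scope_ = nullptr;
};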
@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
 /// Variables with this suffix are the new Gradient.
 constexpr char kNewGradSuffix[] = "@NEWGRAD@";
+/// RuntimeContext is used to relate the input/output names of an Operator
+/// with the corresponding variables in the name scope.
+/// If an Op has the attribute kEnableCacheRuntimeContext, its input/output
+/// names do not change within the same name scope during execution, so the
+/// RuntimeContext needs to be created only at the first iteration of this
+/// Op's execution and can be reused afterwards to save the elapsed time.
+constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
+
 /// If an Op has this attribute, all its kernels should calculate output
 /// variable's shape in the corresponding Compute() function. And
 /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
@@ -456,6 +464,8 @@ class OperatorWithKernel : public OperatorBase {
   // same.
   proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
+  void RunImpl(const Scope& scope, const platform::Place& place,
+               RuntimeContext* runtime_ctx) const;
   /**
    * Transfer data from scope to a transfered scope. If there is no data need to
@@ -479,6 +489,8 @@ class OperatorWithKernel : public OperatorBase {
   mutable OpKernelConfigsMap kernel_configs_map_;
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
+  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
+  mutable const Scope* pre_scope_ = nullptr;
 };
 extern bool OpSupportGPU(const std::string& op_type);
......
@@ -202,6 +202,7 @@ void AnalysisConfig::Update() {
       // Append after the Affine_channel_conv_fuse pass.
       pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
     }
+    pass_builder()->DeletePass("runtime_context_cache_pass");
   }
   if (use_mkldnn_) {
......
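
On the inference side, the same pass list can also be adjusted by users through AnalysisConfig's pass builder, which is the mechanism the DeletePass call above uses internally. A hedged usage sketch; "model_dir" is a placeholder and the option set is an assumption, not part of this commit:

// Sketch: tweaking the pass list through AnalysisConfig's pass builder.
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void ConfigurePasses() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");  // assumed model location
  // Drop the cache pass explicitly, mirroring what Update() does above.
  config.pass_builder()->DeletePass("runtime_context_cache_pass");
}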
@@ -80,6 +80,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_act_fuse_pass",   //
         "conv_elementwise_add2_act_fuse_pass",  //
         "conv_elementwise_add_fuse_pass",       //
+        "runtime_context_cache_pass",           //
 #endif
   });
@@ -115,6 +116,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
         "conv_eltwiseadd_bn_fuse_pass",   //
         "is_test_pass",                   //
         "identity_scale_op_clean_pass",   //
+        "runtime_context_cache_pass",     //
   });
   use_gpu_ = false;
 }
......
@@ -325,7 +325,8 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
       const bool is_output = outputs.find(var_name) != outputs.end();
       if (!is_output &&
           std::find(var_in_.begin(), var_in_.end(), var_name) ==
-              var_in_.end()) {
+              var_in_.end() &&
+          scope_.FindVar(var_name)) {
         // fill var_in here to keep lhs and rhs order
         this->var_in_.emplace_back(var_name);
       }
......
@@ -27,13 +27,9 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
-void BuildCrossEntropyNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetCrossEntropy(
+    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
+    const bool is_soft_label, int ignore_index) {
   auto label_shape = label->get_shape();
   auto x_shape = x->get_shape();
   auto label_rank = label_shape.size();
@@ -46,18 +42,16 @@ void BuildCrossEntropyNode(
     label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
   }
   if (x_rank > 2) {
-    x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_rank - 1);
-    x_2d = paddle::platform::NgReshaper(x, x_2d_shape);
+    x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
+    x_2d = platform::NgReshaper(x, x_2d_shape);
   }
   auto batch_size = x_2d_shape.at(0);
-  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
-  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
   std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
   if (!is_soft_label) {
-    auto label_1d = paddle::platform::NgReshaper(
-        label_2d, ngraph::Shape{label_2d_shape.at(0)});
+    auto label_1d =
+        platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
     node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
   }
   if (x->get_element_type() != node_1_hot->get_element_type()) {
@@ -76,11 +70,9 @@ void BuildCrossEntropyNode(
   auto node_sum =
       std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
   auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
-  auto xe =
-      paddle::platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
+  auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
   if (!is_soft_label) {
-    auto ignore_index = op_attrs.Get<int>("ignore_index");
     auto ignore_node = ngraph::op::Constant::create(
         label->get_element_type(), label_2d_shape, {ignore_index});
     auto not_equal_node =
@@ -89,21 +81,13 @@ void BuildCrossEntropyNode(
                                          xe->get_element_type());
     xe = xe * mask;
   }
+  return xe;
-  paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
 }
-void BuildCrossEntropyGradNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
-  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
-  auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
+    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
+    std::shared_ptr<ngraph::Node> dy, const bool is_soft_label,
+    int ignore_index) {
   auto x_shape = x->get_shape();
   auto rank = x_shape.size();
@@ -111,9 +95,8 @@ void BuildCrossEntropyGradNode(
   if (!is_soft_label) {
     auto label_shape = label->get_shape();
     label_shape.pop_back();
-    label = paddle::platform::NgReshaper(label, label_shape);
+    label = platform::NgReshaper(label, label_shape);
-    auto ignore_index = op_attrs.Get<int>("ignore_index");
     auto ignore_node = ngraph::op::Constant::create(
         label->get_element_type(), label_shape, {ignore_index});
     auto not_equal_node =
@@ -128,7 +111,7 @@ void BuildCrossEntropyGradNode(
   auto dy_shape = dy->get_shape();
   dy_shape.pop_back();
-  auto dy_reshape = paddle::platform::NgReshaper(dy, dy_shape);
+  auto dy_reshape = platform::NgReshaper(dy, dy_shape);
   auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
       dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
   if (x->get_element_type() != label->get_element_type()) {
@@ -140,7 +123,35 @@ void BuildCrossEntropyGradNode(
   if (!is_soft_label) {
     xe_grad = xe_grad * mask;
   }
+  return xe_grad;
+}
+
+void BuildCrossEntropyNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  int ignore_index = op_attrs.Get<int>("ignore_index");
+  auto xe = GetCrossEntropy(x, label, is_soft_label, ignore_index);
+  paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
+}
+
+void BuildCrossEntropyGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  int ignore_index = op_attrs.Get<int>("ignore_index");
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
+  auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
   paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
 }
 }  // namespace ngraphs
......
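
In formula terms, GetCrossEntropy above builds, per row i of the flattened 2-D input, the standard hard-label cross entropy with an ignore mask. This is a sketch: the elementwise log sits in lines elided from this diff and is assumed here.

% Sketch of the computation assembled by GetCrossEntropy (hard-label case);
% the log step is assumed from the elided lines.
\[
  xe_i = -\sum_{j} \operatorname{onehot}(label_i)_j \,\log x_{ij},
  \qquad
  xe_i \leftarrow xe_i \cdot \mathbf{1}\bigl[\,label_i \neq ignore\_index\,\bigr].
\]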
@@ -27,12 +27,7 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
-void BuildSoftmaxNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x) {
   auto x_shape = x->get_shape();
   int rank = x_shape.size();
   auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1);
@@ -47,16 +42,11 @@ void BuildSoftmaxNode(
       -64., x_shifted);
   auto softmax =
       std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1});
-  paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
+  return softmax;
 }
-void BuildSoftmaxGradNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
-  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetSoftmaxGrad(
+    std::shared_ptr<ngraph::Node> out, std::shared_ptr<ngraph::Node> dout) {
   auto out_shape = out->get_shape();
   int rank = out_shape.size();
   auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1);
@@ -70,6 +60,27 @@ void BuildSoftmaxGradNode(
   auto node_bcast = std::make_shared<ngraph::op::Broadcast>(
       node_sum, out_2d_shape, ngraph::AxisSet{1});
   auto dx = (dout - node_bcast) * out;
+  return dx;
+}
+
+void BuildSoftmaxNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto softmax = GetSoftmax(x);
+  paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
+}
+
+void BuildSoftmaxGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
+  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+  auto dx = GetSoftmaxGrad(out, dout);
   paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
 }
 }  // namespace ngraphs
......
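
GetSoftmax and GetSoftmaxGrad above correspond to the usual numerically shifted softmax and its gradient. This is a sketch: the row-max shift behind x_shifted and the row sum of dout*out behind node_sum live in elided lines and are assumed here; only the -64 clipping and the final expressions are visible in the diff.

% Sketch of the forward and backward computations (assumptions noted above).
\[
  \tilde{x}_{ij} = \max\!\bigl(x_{ij} - \max_k x_{ik},\ -64\bigr), \qquad
  out_{ij} = \frac{e^{\tilde{x}_{ij}}}{\sum_k e^{\tilde{x}_{ik}}}, \qquad
  dx_{ij} = \Bigl(dout_{ij} - \sum_k dout_{ik}\,out_{ik}\Bigr)\, out_{ij}.
\]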
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/cross_entropy_op.h"
#include "paddle/fluid/operators/ngraph/ops/softmax_op.h"
#include "paddle/fluid/platform/ngraph_helper.h"

namespace paddle {
namespace operators {
namespace ngraphs {

void BuildSoftmaxWithCrossEntropyNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto logits = paddle::platform::GetInputNode(op, "Logits", ngb_node_map);
  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
  auto softmax = paddle::operators::ngraphs::GetSoftmax(logits);

  auto op_attrs = framework::AttrReader(op->Attrs());
  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
  int ignore_index = op_attrs.Get<int>("ignore_index");
  auto xe = paddle::operators::ngraphs::GetCrossEntropy(
      softmax, label, is_soft_label, ignore_index);

  paddle::platform::SetOutputNode(op, "Softmax", softmax, ngb_node_map);
  paddle::platform::SetOutputNode(op, "Loss", xe, ngb_node_map);
}

void BuildSoftmaxWithCrossEntropyGradNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto op_attrs = framework::AttrReader(op->Attrs());
  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
  auto softmax = paddle::platform::GetInputNode(op, "Softmax", ngb_node_map);
  auto loss_grad =
      paddle::platform::GetInputNode(op, "Loss@GRAD", ngb_node_map);

  auto softmax_shape = softmax->get_shape();
  auto rank = softmax_shape.size();
  if (!is_soft_label) {
    auto label_shape = label->get_shape();
    label_shape.pop_back();
    label = platform::NgReshaper(label, label_shape);
    label =
        std::make_shared<ngraph::op::OneHot>(label, softmax_shape, rank - 1);
  }

  auto loss_grad_shape = loss_grad->get_shape();
  loss_grad_shape.pop_back();
  auto loss_grad_reshape = platform::NgReshaper(loss_grad, loss_grad_shape);
  auto loss_grad_bcast = std::make_shared<ngraph::op::Broadcast>(
      loss_grad_reshape, softmax_shape, ngraph::AxisSet{rank - 1});
  if (softmax->get_element_type() != label->get_element_type()) {
    label = std::make_shared<ngraph::op::Convert>(label,
                                                  softmax->get_element_type());
  }

  auto logits_grad = loss_grad_bcast * (softmax - label);
  paddle::platform::SetOutputNode(op, "Logits@GRAD", logits_grad, ngb_node_map);
}
}  // namespace ngraphs
}  // namespace operators
}  // namespace paddle

REGISTER_NG_OP(softmax_with_cross_entropy, BuildSoftmaxWithCrossEntropyNode);
REGISTER_NG_OP(softmax_with_cross_entropy_grad,
               BuildSoftmaxWithCrossEntropyGradNode);
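
The last lines of BuildSoftmaxWithCrossEntropyGradNode reduce to the familiar closed form for this fused op, with hard labels already converted to one-hot above and loss_grad broadcast over the class axis:

% Closed form of logits_grad = loss_grad_bcast * (softmax - label) above.
\[
  \frac{\partial Loss}{\partial Logits_{ij}}
  = loss\_grad_i \,\bigl(softmax_{ij} - label_{ij}\bigr).
\]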
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
from paddle.fluid.tests.unittests.test_softmax_with_cross_entropy_op import TestSoftmaxWithCrossEntropyOp, TestSoftmaxWithCrossEntropyOp2, TestSoftmaxWithCrossEntropyOp3

if __name__ == "__main__":
    unittest.main()