diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 8d2cc5adec1fe3fb9280981ef2d128fb0803fa24..a79a53867d85e91250ac4810caa5806c25f35fee 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -70,6 +70,7 @@ pass_library(conv_affine_channel_fuse_pass inference)
 pass_library(transpose_flatten_concat_fuse_pass inference)
 pass_library(identity_scale_op_clean_pass base)
 pass_library(sync_batch_norm_pass base)
+pass_library(runtime_context_cache_pass base)
 # There may be many transpose-flatten structures in a model, and the output of
 # these structures will be used as inputs to the concat Op. This pattern will
diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..67b29512c4cf3512e4b2b4b5a18ba60a3d9120dc
--- /dev/null
+++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
@@ -0,0 +1,39 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
+#include <memory>
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  VLOG(3) << "Applies Runtime Context Cache strategy.";
+  for (const Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
+    }
+  }
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(runtime_context_cache_pass,
+              paddle::framework::ir::RuntimeContextCachePass);
diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.h b/paddle/fluid/framework/ir/runtime_context_cache_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6cf1a9ae5035f185dd3ab52bf0762a6eaf0f6e5
--- /dev/null
+++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class RuntimeContextCachePass : public Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index a04646cd01765dee83de863e9cb71392f484f5a6..1d3a38cc286fdbc26d62a02e9d9086feb2826a07 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -876,12 +876,27 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
 
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
-  RuntimeContext ctx(Inputs(), Outputs(), scope);
+  if (!HasAttr(kEnableCacheRuntimeContext)) {
+    RuntimeContext ctx(Inputs(), Outputs(), scope);
+    RunImpl(scope, place, &ctx);
+  } else {
+    const Scope* cur_scope = &scope;
+    if (!runtime_ctx_ || pre_scope_ != cur_scope) {
+      runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
+      pre_scope_ = cur_scope;
+    }
+    RunImpl(scope, place, runtime_ctx_.get());
+  }
+}
+
+void OperatorWithKernel::RunImpl(const Scope& scope,
+                                 const platform::Place& place,
+                                 RuntimeContext* runtime_ctx) const {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
 
   if (!kernel_type_) {
-    ChooseKernel(ctx, scope, place);
+    ChooseKernel(*runtime_ctx, scope, place);
   }
 
   std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
@@ -889,7 +904,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
   auto* transfer_scope =
-      PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx);
+      PrepareData(scope, *kernel_type_, &transfered_inplace_vars, runtime_ctx);
 
   // exec scope is the scope that kernel actually executed on.
   const Scope& exec_scope =
@@ -900,13 +915,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 
   if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
-    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
+    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
     this->InferShape(&infer_shape_ctx);
   }
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
-  (*kernel_func_)(
-      ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
+  (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
+                                   kernel_configs));
 
   if (!transfered_inplace_vars.empty()) {
     // there is inplace variable has been transfered.
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index b9936a2350c45ab5b91cb50e1f6e66e28c63142d..fb7829a12cd1efcab115b5136025c6f324505f2f 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
 /// Variables with this suffix are the new Gradient.
 constexpr char kNewGradSuffix[] = "@NEWGRAD@";
 
+/// RuntimeContext is used to relate input/output names of Operator with
+/// the corresponding variables in name scope.
+/// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same
+/// name scope, since the input/output names of this Op do not change in the
+/// execution, RuntimeContext could be created only at the first iteration of
+/// this Op's execution to save the elapsed time.
+constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
+
 /// If an Op has this attribute, all its kernels should calculate output
 /// variable's shape in the corresponding Compute() function. And
 /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
@@ -456,6 +464,8 @@ class OperatorWithKernel : public OperatorBase {
   // same.
   proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
+  void RunImpl(const Scope& scope, const platform::Place& place,
+               RuntimeContext* runtime_ctx) const;
 
   /**
    * Transfer data from scope to a transfered scope. If there is no data need to
@@ -479,6 +489,8 @@ class OperatorWithKernel : public OperatorBase {
   mutable OpKernelConfigsMap kernel_configs_map_;
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
+  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
+  mutable const Scope* pre_scope_ = nullptr;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
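
[Reviewer note, not part of the patch] The kEnableCacheRuntimeContext comment above is the core idea of this change: within one name scope an operator's input/output names never change, so the RuntimeContext that maps those names to variables can be built once, on the first iteration, and reused until the scope itself changes. A minimal standalone sketch of that caching pattern, using hypothetical Scope/RuntimeContext/CachedOp stand-ins rather than the real framework classes, looks roughly like this:

// Illustrative sketch only; Scope, RuntimeContext and CachedOp are simplified
// stand-ins for the framework classes touched in operator.cc above.
#include <memory>

struct Scope {};

struct RuntimeContext {
  explicit RuntimeContext(const Scope& scope) {
    // In the real framework this resolves input/output names to variables.
    (void)scope;
  }
};

class CachedOp {
 public:
  void Run(const Scope& scope, bool enable_cache) {
    if (!enable_cache) {
      RuntimeContext ctx(scope);  // rebuilt on every call (old behavior)
      Compute(ctx);
      return;
    }
    // Rebuild only when the scope changes; otherwise reuse the cached context.
    if (!runtime_ctx_ || pre_scope_ != &scope) {
      runtime_ctx_.reset(new RuntimeContext(scope));
      pre_scope_ = &scope;
    }
    Compute(*runtime_ctx_);
  }

 private:
  void Compute(const RuntimeContext& ctx) { (void)ctx; /* kernel dispatch */ }

  std::unique_ptr<RuntimeContext> runtime_ctx_;
  const Scope* pre_scope_ = nullptr;
};
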
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 92526f4e74a217aa2cdfd43f258846ada54b9374..1be25de497346913f24eec147a2db58b0f7065f4 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -202,6 +202,7 @@ void AnalysisConfig::Update() {
       // Append after the Affine_channel_conv_fuse pass.
       pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
     }
+    pass_builder()->DeletePass("runtime_context_cache_pass");
   }
 
   if (use_mkldnn_) {
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 92c24647e87a096e7cfbbf69876b678fe48842a4..22c527cfc117a5e6ababf264744745e41e0bf71a 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -80,6 +80,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_act_fuse_pass",   //
         "conv_elementwise_add2_act_fuse_pass",  //
         "conv_elementwise_add_fuse_pass",       //
+        "runtime_context_cache_pass",           //
 #endif
   });
 
@@ -115,6 +116,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
                 "conv_eltwiseadd_bn_fuse_pass",   //
                 "is_test_pass",                   //
                 "identity_scale_op_clean_pass",   //
+                "runtime_context_cache_pass",     //
   });
 
   use_gpu_ = false;
 }
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc
index 014c9ecca4aa7875e5188b33595f5e8f79c1a9db..9f73bbc1fdc72766a0b57bc72c62d208277c2f20 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -325,7 +325,8 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
       const bool is_output = outputs.find(var_name) != outputs.end();
       if (!is_output &&
           std::find(var_in_.begin(), var_in_.end(), var_name) ==
-              var_in_.end()) {
+              var_in_.end() &&
+          scope_.FindVar(var_name)) {
         // fill var_in here to keep lhs and rhs order
         this->var_in_.emplace_back(var_name);
       }
diff --git a/paddle/fluid/operators/ngraph/ops/cross_entropy_op.h b/paddle/fluid/operators/ngraph/ops/cross_entropy_op.h
index be36b9d21ef6ebe5c11d783462e7dc564afe2aba..c92ebb7e96fa22f8fd463c5837134cd74542766c 100644
--- a/paddle/fluid/operators/ngraph/ops/cross_entropy_op.h
+++ b/paddle/fluid/operators/ngraph/ops/cross_entropy_op.h
@@ -27,13 +27,9 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
 
-void BuildCrossEntropyNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetCrossEntropy(
+    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
+    const bool is_soft_label, int ignore_index) {
   auto label_shape = label->get_shape();
   auto x_shape = x->get_shape();
   auto label_rank = label_shape.size();
@@ -46,18 +42,16 @@ label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
   }
   if (x_rank > 2) {
-    x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_rank - 1);
-    x_2d = paddle::platform::NgReshaper(x, x_2d_shape);
+    x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
+    x_2d = platform::NgReshaper(x, x_2d_shape);
   }
   auto batch_size = x_2d_shape.at(0);
 
-  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
-  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
   std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
   if (!is_soft_label) {
-    auto label_1d = paddle::platform::NgReshaper(
-        label_2d, ngraph::Shape{label_2d_shape.at(0)});
+    auto label_1d =
+        platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
     node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
   }
   if (x->get_element_type() != node_1_hot->get_element_type()) {
@@ -76,11 +70,9 @@ auto node_sum = std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
   auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
-  auto xe =
-      paddle::platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
+  auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
 
   if (!is_soft_label) {
-    auto ignore_index = op_attrs.Get<int>("ignore_index");
     auto ignore_node = ngraph::op::Constant::create(
         label->get_element_type(), label_2d_shape, {ignore_index});
     auto not_equal_node =
@@ -89,21 +81,13 @@ xe->get_element_type());
     xe = xe * mask;
   }
-
-  paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
+  return xe;
 }
 
-void BuildCrossEntropyGradNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
-  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
-
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
-  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
-  auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
+    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
+    std::shared_ptr<ngraph::Node> dy, const bool is_soft_label,
+    int ignore_index) {
   auto x_shape = x->get_shape();
   auto rank = x_shape.size();
 
@@ -111,9 +95,8 @@ if (!is_soft_label) {
     auto label_shape = label->get_shape();
     label_shape.pop_back();
-    label = paddle::platform::NgReshaper(label, label_shape);
+    label = platform::NgReshaper(label, label_shape);
 
-    auto ignore_index = op_attrs.Get<int>("ignore_index");
     auto ignore_node = ngraph::op::Constant::create(
         label->get_element_type(), label_shape, {ignore_index});
     auto not_equal_node =
@@ -128,7 +111,7 @@ auto dy_shape = dy->get_shape();
   dy_shape.pop_back();
-  auto dy_reshape = paddle::platform::NgReshaper(dy, dy_shape);
+  auto dy_reshape = platform::NgReshaper(dy, dy_shape);
   auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
       dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
   if (x->get_element_type() != label->get_element_type()) {
@@ -140,7 +123,35 @@ if (!is_soft_label) {
     xe_grad = xe_grad * mask;
   }
+  return xe_grad;
+}
 
+void BuildCrossEntropyNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  int ignore_index = op_attrs.Get<int>("ignore_index");
+  auto xe = GetCrossEntropy(x, label, is_soft_label, ignore_index);
+  paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
+}
+
+void BuildCrossEntropyGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  int ignore_index = op_attrs.Get<int>("ignore_index");
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
+  auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
   paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
 }
 }  // namespace ngraphs
diff --git a/paddle/fluid/operators/ngraph/ops/softmax_op.h b/paddle/fluid/operators/ngraph/ops/softmax_op.h
index 7d5720c460c4194ce06670a715b8d7ff4435bb2a..174b7a91a8dd0e3edb06f224c3914e24c6c4a96d 100644
--- a/paddle/fluid/operators/ngraph/ops/softmax_op.h
+++ b/paddle/fluid/operators/ngraph/ops/softmax_op.h
@@ -27,12 +27,7 @@ namespace paddle {
 namespace operators {
 namespace ngraphs {
 
-void BuildSoftmaxNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x) {
   auto x_shape = x->get_shape();
   int rank = x_shape.size();
   auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1);
@@ -47,16 +42,11 @@ -64., x_shifted);
   auto softmax =
       std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1});
-  paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
+  return softmax;
 }
 
-void BuildSoftmaxGradNode(
-    const std::shared_ptr<paddle::framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
-  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+std::shared_ptr<ngraph::Node> GetSoftmaxGrad(
+    std::shared_ptr<ngraph::Node> out, std::shared_ptr<ngraph::Node> dout) {
   auto out_shape = out->get_shape();
   int rank = out_shape.size();
   auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1);
@@ -70,6 +60,27 @@ auto node_bcast = std::make_shared<ngraph::op::Broadcast>(
       node_sum, out_2d_shape, ngraph::AxisSet{1});
   auto dx = (dout - node_bcast) * out;
+  return dx;
+}
+
+void BuildSoftmaxNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+  auto softmax = GetSoftmax(x);
+  paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
+}
+
+void BuildSoftmaxGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
+  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+  auto dx = GetSoftmaxGrad(out, dout);
   paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
 }
 }  // namespace ngraphs
diff --git a/paddle/fluid/operators/ngraph/ops/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/ngraph/ops/softmax_with_cross_entropy_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6bdf4de9522e08caf4a9ae606db8277f98cdab3
--- /dev/null
+++ b/paddle/fluid/operators/ngraph/ops/softmax_with_cross_entropy_op.h
@@ -0,0 +1,90 @@
+/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "ngraph/ngraph.hpp"
+#include "paddle/fluid/operators/ngraph/ops/cross_entropy_op.h"
+#include "paddle/fluid/operators/ngraph/ops/softmax_op.h"
+#include "paddle/fluid/platform/ngraph_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace ngraphs {
+
+void BuildSoftmaxWithCrossEntropyNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto logits = paddle::platform::GetInputNode(op, "Logits", ngb_node_map);
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto softmax = paddle::operators::ngraphs::GetSoftmax(logits);
+
+  auto op_attrs = framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  int ignore_index = op_attrs.Get<int>("ignore_index");
+  auto xe = paddle::operators::ngraphs::GetCrossEntropy(
+      softmax, label, is_soft_label, ignore_index);
+
+  paddle::platform::SetOutputNode(op, "Softmax", softmax, ngb_node_map);
+  paddle::platform::SetOutputNode(op, "Loss", xe, ngb_node_map);
+}
+
+void BuildSoftmaxWithCrossEntropyGradNode(
+    const std::shared_ptr<paddle::framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto op_attrs = framework::AttrReader(op->Attrs());
+  const bool is_soft_label = op_attrs.Get<bool>("soft_label");
+  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
+  auto softmax = paddle::platform::GetInputNode(op, "Softmax", ngb_node_map);
+  auto loss_grad =
+      paddle::platform::GetInputNode(op, "Loss@GRAD", ngb_node_map);
+  auto softmax_shape = softmax->get_shape();
+  auto rank = softmax_shape.size();
+  if (!is_soft_label) {
+    auto label_shape = label->get_shape();
+    label_shape.pop_back();
+    label = platform::NgReshaper(label, label_shape);
+
+    label =
+        std::make_shared<ngraph::op::OneHot>(label, softmax_shape, rank - 1);
+  }
+
+  auto loss_grad_shape = loss_grad->get_shape();
+  loss_grad_shape.pop_back();
+  auto loss_grad_reshape = platform::NgReshaper(loss_grad, loss_grad_shape);
+  auto loss_grad_bcast = std::make_shared<ngraph::op::Broadcast>(
+      loss_grad_reshape, softmax_shape, ngraph::AxisSet{rank - 1});
+  if (softmax->get_element_type() != label->get_element_type()) {
+    label = std::make_shared<ngraph::op::Convert>(label,
+                                                  softmax->get_element_type());
+  }
+
+  auto logits_grad = loss_grad_bcast * (softmax - label);
+  paddle::platform::SetOutputNode(op, "Logits@GRAD", logits_grad, ngb_node_map);
+}
+}  // namespace ngraphs
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_NG_OP(softmax_with_cross_entropy, BuildSoftmaxWithCrossEntropyNode);
+REGISTER_NG_OP(softmax_with_cross_entropy_grad,
+               BuildSoftmaxWithCrossEntropyGradNode);
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_softmax_with_cross_entropy_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_with_cross_entropy_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..86961b8c366c69a210e47ab5d1ece6ba85d1d262
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_softmax_with_cross_entropy_ngraph_op.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+from paddle.fluid.tests.unittests.test_softmax_with_cross_entropy_op import TestSoftmaxWithCrossEntropyOp, TestSoftmaxWithCrossEntropyOp2, TestSoftmaxWithCrossEntropyOp3
+
+if __name__ == "__main__":
+    unittest.main()
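
[Reviewer note, not part of the patch] Since runtime_context_cache_pass is now appended to both CpuPassStrategy and GpuPassStrategy, analysis-predictor users get the cached RuntimeContext by default, and AnalysisConfig::Update() removes the pass again when the TensorRT subgraph engine is enabled. A user who wants to opt out manually could go through the same PassStrategy interface; a rough sketch, assuming only the public AnalysisConfig/PassStrategy calls already used in this patch, is:

// Hypothetical helper: drops the pass added by this patch from an
// AnalysisConfig before the predictor is created.
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void DisableRuntimeContextCache(paddle::AnalysisConfig* config) {
  // Same call AnalysisConfig::Update() uses on the TensorRT path above.
  config->pass_builder()->DeletePass("runtime_context_cache_pass");
}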