未验证 提交 ee2028a1 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add use_cuda to inplace pass (#17205)

* add use_cuda to inplace pass,test=develop

* add test softmax_with_xe_inplace test,test=develop
上级 f2db475a
......@@ -311,6 +311,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "inplace_pass") {
pass->Erase(kUseCuda);
pass->Set<bool>(kUseCuda, new bool(use_cuda));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......
......@@ -33,6 +33,19 @@ namespace details {
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
const OpToVarNameSetMap &map) {
std::map<size_t, std::unordered_set<std::string>> result;
for (auto &pair : map) {
size_t scope_idx = pair.first->GetScopeIdx();
auto &var_set = result[scope_idx];
for (auto &var : pair.second) {
var_set.insert(var);
}
}
return result;
}
// Check whether the variable is LoDTensor based on static VarDesc info
static bool IsLoDTensor(VarDesc *var) {
return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
......@@ -236,6 +249,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
if (VLOG_IS_ON(10)) {
auto vars_group_by_scope_idx = VarsGroupByScopeIdx(op_vars_map);
for (auto &pair : vars_group_by_scope_idx) {
VLOG(10) << "Scope " << pair.first << " has " << pair.second.size()
<< " vars";
}
}
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
while_op_eager_deletion_pass->Apply(graph);
......
......@@ -111,9 +111,9 @@ class InplacePass : public ir::Pass {
// Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find node whose name is equal to the given name
static ir::Node *FindNodeByName(const std::string &name,
const std::vector<ir::Node *> &nodes);
// Find nodes whose name are equal to the given name
static std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes);
// Get all versions vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
......@@ -290,17 +290,15 @@ void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
op->Op()->Flush();
}
ir::Node *InplacePass::FindNodeByName(const std::string &name,
const std::vector<ir::Node *> &nodes) {
ir::Node *found_node = nullptr;
std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes) {
std::unordered_set<ir::Node *> ret;
for (auto *node : nodes) {
if (node->Name() == name) {
PADDLE_ENFORCE(found_node == nullptr, "Find duplicate input nodes %s",
name);
found_node = node;
ret.insert(node);
}
}
return found_node;
return ret;
}
void InplacePass::ApplyImpl(ir::Graph *graph) const {
......@@ -326,6 +324,10 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
// Step 3: traverse ops and try inplace if possible
bool use_cuda = Get<bool>(kUseCuda);
VLOG(4) << "Inplace pass is applied when use_cuda = "
<< (use_cuda ? "true" : "false");
for (auto *op_node : ops) {
PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr");
......@@ -343,7 +345,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
auto in_to_outs = infer_inplace(*op_desc);
auto in_to_outs = infer_inplace(*op_desc, use_cuda);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &out_param = pair.second;
......@@ -385,9 +387,17 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
auto *in_node = FindNodeByName(in_arg, op_node->inputs);
PADDLE_ENFORCE_NOT_NULL(in_node, "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
if (in_nodes.size() > 1) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs in other inputs of " << op_type;
continue;
}
auto *in_node = *in_nodes.begin();
if (!NodeCanReused(in_node)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
......@@ -410,10 +420,29 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
auto *out_node = FindNodeByName(out_arg, op_node->outputs);
PADDLE_ENFORCE_NOT_NULL(out_node,
"Output(%s)=%s cannot be found in op %s",
out_param, out_arg, op_type);
auto out_nodes = FindNodesByName(out_arg, op_node->outputs);
PADDLE_ENFORCE(!out_nodes.empty(),
"Output(%s)=%s cannot be found in op %s", out_param,
out_arg, op_type);
PADDLE_ENFORCE_EQ(
out_nodes.size(), 1,
"Wrong graph: Output(%s)=%s occurs in other outputs of op %s",
out_param, out_arg, op_type);
if (!FindNodesByName(in_arg, op_node->outputs).empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs in output of op " << op_type;
continue;
}
if (!FindNodesByName(out_arg, op_node->inputs).empty()) {
VLOG(4) << "Cannot inplace because Output(" << in_param
<< ")=" << out_arg << " occurs in input of op " << op_type;
continue;
}
auto *out_node = *out_nodes.begin();
if (!NodeCanReused(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
......@@ -457,4 +486,5 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass);
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
.RequirePassAttr(paddle::framework::details::kUseCuda);
......@@ -36,6 +36,8 @@ namespace details {
constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
typedef std::unordered_set<std::string> MemOptSkipVars;
constexpr char kUseCuda[] = "use_cuda";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// NOTE(dzh): A ordered set for node reuse in memory optimize.
......
......@@ -214,9 +214,9 @@ struct OpInfoFiller<T, kShapeInference> {
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_inplace_ = [](const OpDesc& op_desc) {
info->infer_inplace_ = [](const OpDesc& op_desc, bool use_cuda) {
T infer;
return infer(op_desc);
return infer(op_desc, use_cuda);
};
}
};
......
......@@ -37,7 +37,7 @@ class InplaceOpInference {
public:
virtual ~InplaceOpInference() {}
virtual std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc) const = 0;
const OpDesc& op_desc, bool use_cuda) const = 0;
};
/*
......@@ -47,7 +47,7 @@ class InplaceOpInference {
class SingleOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc) const override {
const OpDesc& op_desc, bool use_cuda) const override {
PADDLE_ENFORCE(!op_desc.InputNames().empty(),
"Op inputs must not be empty");
PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
......@@ -65,7 +65,7 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
class GradOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc) const override {
const OpDesc& op_desc, bool use_cuda) const override {
std::unordered_map<std::string, std::string> ret;
std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
op_desc.OutputNames().end());
......
......@@ -32,7 +32,9 @@ namespace paddle {
namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
return ir::PassRegistry::Instance().Get("inplace_pass");
auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(details::kUseCuda, new bool(true));
return pass;
}
class NOP : public OperatorBase {
......@@ -141,7 +143,7 @@ class MultiOutGradShapeInference : public framework::InferShapeBase {
class MultiOutInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc) const override {
const OpDesc& op_desc, bool use_cuda) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"},
};
......@@ -151,7 +153,7 @@ class MultiOutInplaceInToOut : public framework::InplaceOpInference {
class MultiOutGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc) const override {
const OpDesc& op_desc, bool use_cuda) const override {
return std::unordered_map<std::string, std::string>{
{framework::GradVarName("YOut"), framework::GradVarName("Y")},
{framework::GradVarName("Out"), framework::GradVarName("X")},
......
......@@ -60,7 +60,7 @@ using InferVarTypeFN =
using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&)>;
using InferInplaceOpFN = std::function<InplacePair(const OpDesc&, bool)>;
using InferNoNeedBufferVarsFN = std::function<std::unordered_set<std::string>(
const VariableNameMap& /*inputs*/, const VariableNameMap& /*outputs*/,
......
......@@ -600,25 +600,21 @@ std::unique_ptr<framework::OpDesc> BatchNormGradMaker::Apply() const {
class BatchNormInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"},
};
return inplace_in_to_out;
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}};
}
};
class BatchNormGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
const framework::OpDesc &op_desc, bool use_cuda) const override {
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
return {
{framework::GradVarName("Y"), framework::GradVarName("X")},
{"SavedMean", framework::GradVarName("Scale")},
{"SavedVariance", framework::GradVarName("Bias")},
};
return inplace_in_to_out;
}
};
......
......@@ -255,20 +255,16 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
class ElementwiseOpInplace : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"},
};
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
class ElementwiseGradOpInplace : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
return std::unordered_map<std::string, std::string>{
{framework::GradVarName("Out"), framework::GradVarName("X")},
};
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
......
......@@ -270,22 +270,16 @@ class Flatten2GradOp : public framework::OperatorBase {
class FlattenOpInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"X", "Out"},
};
return inplace_in_to_out;
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
class FlattenGradInplaceinToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{framework::GradVarName("Out"), framework::GradVarName("X")},
};
return inplace_in_to_out;
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
......
......@@ -173,7 +173,7 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
class GroupNormInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Y"}};
}
};
......@@ -181,7 +181,7 @@ class GroupNormInplaceInToOut : public framework::InplaceOpInference {
class GroupNormGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Y"), framework::GradVarName("X")}};
}
};
......
......@@ -325,22 +325,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
class ReshapeOpInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{"X", "Out"},
};
return inplace_in_to_out;
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
class ReshapeGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc &op_desc) const override {
std::unordered_map<std::string, std::string> inplace_in_to_out = {
{framework::GradVarName("Out"), framework::GradVarName("X")},
};
return inplace_in_to_out;
const framework::OpDesc &op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
......
......@@ -228,11 +228,24 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
}
};
class SoftmaxWithCrossEntropyInplaceInference
: public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const {
if (use_cuda && !boost::get<bool>(op_desc.GetAttr("soft_label"))) {
return {{"Logits", "Softmax"}};
} else {
return {};
}
}
};
class SoftmaxWithCrossEntropyGradInplaceInference
: public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc) const {
const framework::OpDesc& op_desc, bool use_cuda) const {
return {{"Softmax", framework::GradVarName("Logits")}};
}
};
......@@ -243,7 +256,8 @@ class SoftmaxWithCrossEntropyGradInplaceInference
namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker,
ops::SoftmaxWithCrossEntropyInplaceInference);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInference);
......
......@@ -183,8 +183,7 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
// Make sure that BlockDim <= feature_size
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
int feature_size) {
const T* labels_data, T* loss_data, T* softmax, int feature_size) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
......@@ -210,11 +209,9 @@ static __global__ void RowReductionForSoftmaxAndCrossEntropy(
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
public:
HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits,
const int64_t* labels, T* loss,
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
T* log_softmax, int feature_size)
: logits_(logits),
labels_(labels),
: labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
feature_size_(feature_size) {}
......@@ -232,7 +229,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
}
private:
const T* logits_;
const int64_t* labels_;
T* loss_;
T* log_softmax_;
......@@ -242,13 +238,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const T* logits,
const int64_t* labels,
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
T* loss, T* log_softmax,
int feature_size,
int ignore_idx)
: logits_(logits),
labels_(labels),
: labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
feature_size_(feature_size),
......@@ -267,7 +261,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
}
private:
const T* logits_;
const int64_t* labels_;
T* loss_;
T* log_softmax_;
......@@ -293,23 +286,22 @@ static void HardLabelSoftmaxWithCrossEntropy(
: (1 << static_cast<int>(std::log2(feature_size)));
auto stream = ctx.stream();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
platform::ForRange<platform::CUDADeviceContext> for_range( \
ctx, batch_size* feature_size); \
if (ignore_idx >= 0 && ignore_idx < feature_size) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size, \
ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size)); \
} \
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
platform::ForRange<platform::CUDADeviceContext> for_range( \
ctx, batch_size* feature_size); \
if (ignore_idx >= 0 && ignore_idx < feature_size) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, softmax_data, feature_size, ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
labels_data, loss_data, softmax_data, feature_size)); \
} \
} break
switch (block_dim) {
......@@ -356,7 +348,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
logits_data, loss_data, softmax_data, feature_size); \
RowReductionForSoftmaxAndCrossEntropy< \
T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, softmax_data, feature_size); \
labels_data, loss_data, softmax_data, feature_size); \
break
switch (block_dim) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/var_type_inference.h"
......@@ -237,13 +238,21 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
}
};
class SumInplace : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const override {
return {{"X", "Out"}};
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference);
ops::SumOpVarTypeInference, ops::SumInplace);
REGISTER_OP_CPU_KERNEL(
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid import layers
import numpy as np
import unittest
class TestSoftmaxWithXe(unittest.TestCase):
def setUp(self):
self.m, self.n = np.random.random_integers(
low=100, high=2000, size=[2]).astype('int64')
def softmax_with_xe(self, x, y, place, inplace=True):
m, n = x.shape
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
x_d = fluid.layers.data(
name='x',
shape=[m, n],
dtype='float32',
append_batch_size=False)
y_d = fluid.layers.data(
name='y',
shape=[m, 1],
dtype='int64',
append_batch_size=False)
z_d, s_d = fluid.layers.softmax_with_cross_entropy(
x_d, y_d, return_softmax=True)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = inplace
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
build_strategy=build_strategy, places=place)
if inplace and isinstance(place, fluid.CUDAPlace):
fetch_list = [z_d.name, x_d.name]
else:
fetch_list = [z_d.name, s_d.name]
z, s = exe.run(prog,
feed={x_d.name: x,
y_d.name: y},
fetch_list=fetch_list)
return z, s
def main_with_place(self, place):
x = np.random.random(size=[self.m, self.n]).astype('float32')
x_range = [(-30, 30), (10, 20), (-1, 1), (2, 3), (0, 0.3), (-200, -100)]
for a, b in x_range:
x = ((b - a) * x + a).astype('float32')
y = np.random.random_integers(
size=[self.m, 1], low=0, high=self.n - 1).astype('int64')
z1, s1 = self.softmax_with_xe(x, y, place, False)
z2, s2 = self.softmax_with_xe(x, y, place, True)
self.assertTrue((z1 == z2).all())
self.assertTrue((s1 == s2).all())
def test_main(self):
self.main_with_place(fluid.CPUPlace())
if fluid.core.is_compiled_with_cuda():
self.main_with_place(fluid.CUDAPlace(0))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册