提交 773dc73f 编写于 作者: L liuhongyu

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_cudnn_5_support

......@@ -194,6 +194,8 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
......@@ -48,7 +48,14 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
// this is a distributed or inter-process call, find a better way.
#ifdef PADDLE_WITH_CUDA
if (NoDummyInputSize() == 1 &&
local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) {
#else
if (NoDummyInputSize() == 1) {
#endif
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
......
......@@ -62,6 +62,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
multi_devices_pass->Set<int>("num_trainers",
new int(strategy_.num_trainers_));
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
......
......@@ -133,6 +133,7 @@ static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNumTrainers[] = "num_trainers";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
......@@ -299,6 +300,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
int num_trainers = Get<int>(kNumTrainers);
for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var());
......@@ -383,7 +386,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}
if (!is_forwarding && places_.size() > 1) {
if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
......@@ -895,4 +898,5 @@ REGISTER_PASS(multi_devices_pass,
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy);
.RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumTrainers);
......@@ -14,11 +14,13 @@
#include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
DEFINE_bool(init_allocated_mem, false,
"It is a mistake that the values of the memory allocated by "
......@@ -110,19 +112,21 @@ size_t Used<platform::CPUPlace>(const platform::CPUPlace &place) {
BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
static std::once_flag init_flag;
static detail::BuddyAllocator **a_arr = nullptr;
static std::vector<int> devices;
std::call_once(init_flag, [gpu_id]() {
int gpu_num = platform::GetCUDADeviceCount();
PADDLE_ENFORCE(gpu_id < gpu_num, "gpu_id:%d should < gpu_num:%d", gpu_id,
gpu_num);
devices = platform::GetSelectedDevices();
int gpu_num = devices.size();
a_arr = new BuddyAllocator *[gpu_num];
for (int i = 0; i < gpu_num; i++) {
for (size_t i = 0; i < devices.size(); ++i) {
int dev_id = devices[i];
a_arr[i] = nullptr;
platform::SetDeviceId(i);
a_arr[i] = new BuddyAllocator(
std::unique_ptr<detail::SystemAllocator>(new detail::GPUAllocator(i)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
platform::SetDeviceId(dev_id);
a_arr[i] = new BuddyAllocator(std::unique_ptr<detail::SystemAllocator>(
new detail::GPUAllocator(dev_id)),
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100
......@@ -134,7 +138,9 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
});
platform::SetDeviceId(gpu_id);
return a_arr[gpu_id];
auto pos = std::distance(devices.begin(),
std::find(devices.begin(), devices.end(), gpu_id));
return a_arr[pos];
}
#endif
......
......@@ -76,8 +76,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
ctx.GetPlace(), layout, library);
framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
library);
}
class ActivationOp : public framework::OperatorWithKernel {
......
......@@ -41,6 +41,12 @@ static std::unordered_set<std::string> InplaceOpSet = {
"floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid",
};
/* The following operator can be used to process SelectedRows, because the
* output of those operator for zero is zero too.
*/
static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
"abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};
static bool IsInplace(std::string op) { return InplaceOpSet.count(op); }
template <typename DeviceContext, typename Functor>
......@@ -50,16 +56,38 @@ class ActivationKernel
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
"Cannot get input tensor X, variable name = %s",
context.op().Input("X"));
auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
"Cannot get output tensor Out, variable name = %s",
context.op().Output("Out"));
Out.mutable_data<T>(context.GetPlace());
auto x_var = context.InputVar("X");
auto out_var = context.OutputVar("Out");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input Variable X, variable name = %s",
context.op().Input("X"));
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get output Variable Out, variable name = %s",
context.op().Output("Out"));
framework::Tensor X, *Out;
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
X = detail::Ref(
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var),
"Cannot get input Tensor X, variable name = %s",
context.op().Input("X"));
Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
out_var);
} else {
X = detail::Ref(context.Input<framework::Tensor>("X"),
"Cannot get input Tensor X, variable name = %s",
context.op().Input("X"));
Out = context.Output<framework::Tensor>("Out");
}
PADDLE_ENFORCE(Out != nullptr,
"Cannot get output tensor Out, variable name = %s",
context.op().Output("Out"));
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(X);
auto out = framework::EigenVector<T>::Flatten(Out);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
......@@ -78,14 +106,54 @@ class ActivationGradKernel
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto out_var = context.InputVar("Out");
auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
context.op().Input("Out"));
PADDLE_ENFORCE(out_grad_var != nullptr,
"Cannot get input Variable %s, variable name = %s",
framework::GradVarName("Out"),
context.op().Input(framework::GradVarName("Out")));
PADDLE_ENFORCE(x_grad_var != nullptr,
"Cannot get output Variable %s, variable name = %s",
framework::GradVarName("X"),
context.op().Output(framework::GradVarName("X")));
framework::Tensor Out, dOut, *dX;
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
Out = detail::Ref(
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var),
"Cannot get input Tensor Out, variable name = %s",
context.op().Input("Out"));
dOut =
detail::Ref(paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
*out_grad_var),
"Cannot get input Tensor %s, variable name = %s",
framework::GradVarName("Out"),
context.op().Input(framework::GradVarName("Out")));
dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
x_grad_var);
} else {
Out = detail::Ref(context.Input<framework::Tensor>("Out"),
"Cannot get input Tensor Out, variable name = %s",
context.op().Input("Out"));
dOut = detail::Ref(
context.Input<framework::Tensor>(framework::GradVarName("Out")),
"Cannot get input Tensor %s, variable name = %s",
framework::GradVarName("Out"),
context.op().Input(framework::GradVarName("Out")));
dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
}
PADDLE_ENFORCE(dX != nullptr,
"Cannot get output tensor %s, variable name = %s",
framework::GradVarName("X"),
context.op().Output(framework::GradVarName("X")));
dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto dout = framework::EigenVector<T>::Flatten(dOut);
auto out = framework::EigenVector<T>::Flatten(Out);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
......@@ -96,8 +164,19 @@ class ActivationGradKernel
}
bool inplace = functor.Inplace();
if (!inplace) {
auto* X = context.Input<framework::Tensor>("X");
auto x = framework::EigenVector<T>::Flatten(*X);
auto x_var = context.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input tensor X, variable name = %s",
context.op().Input("X"));
framework::Tensor X;
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
X = detail::Ref(
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var));
} else {
X = detail::Ref(context.Input<framework::Tensor>("X"));
}
auto x = framework::EigenVector<T>::Flatten(X);
functor(*place, x, out, dout, dx);
} else {
VLOG(10) << " Inplace activation ";
......
......@@ -60,15 +60,37 @@ template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input Variable X, variable name = %s",
ctx.op().Input("X"));
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
framework::Tensor x, *z;
if (x_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(y->dims().size() == 1 && y->dims()[0] == 1,
"For elementwise_op, if X is Sparse, Y must be scalar.");
auto& x_sele = x_var->Get<framework::SelectedRows>();
auto out_sele = ctx.Output<framework::SelectedRows>("Out");
x = x_sele.value();
out_sele->set_rows(x_sele.rows());
out_sele->set_height(x_sele.height());
out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
} else if (x_var->IsType<framework::LoDTensor>()) {
x = x_var->Get<framework::LoDTensor>();
z = ctx.Output<framework::LoDTensor>("Out");
} else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
x_var->Type().name());
}
z->mutable_data<T>(ctx.GetPlace());
if (x->numel() == y->numel()) {
elementwise_mul<DeviceContext, T>(ctx, x, y, z);
if (x.numel() == y->numel()) {
elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
} else {
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
}
}
};
......
......@@ -40,21 +40,28 @@ class ElementwiseOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front());
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
"The input var's type should be LoDTensor, but the received is %s [%s]",
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
} else if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
(ctx->GetInputDim("Y")[0] == 1),
"For elementwise_op, if X is Sparse, "
"Y must be scalar.");
} else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {
class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"GetTensorFromSelectedRowsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"GetTensorFromSelectedRowsOp must has output Out.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS,
"The input X's type should be SelectedRows, but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
PADDLE_ENFORCE(
ctx->GetOutputsVarType("Out").front() ==
framework::proto::VarType::LOD_TENSOR,
"The output Out's type should be LoDTensor, but the received is %s",
ctx->Outputs("Out").front(), ctx->GetOutputsVarType("Out").front());
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context());
}
};
class GetTensorFromSelectedRowsKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.Input<framework::SelectedRows>("X");
auto *out = ctx.Output<framework::LoDTensor>("Out");
out->Resize(x->value().dims());
out->mutable_data(ctx.GetPlace(), x->value().type());
framework::TensorCopy(x->value(), ctx.GetPlace(), ctx.device_context(),
out);
}
};
class GetTensorFromSelectedRowsOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input type is SelectedRows.");
AddOutput("Out", "The output type is LoDTensor.");
AddComment(
R"DOC(
GetTensorFromSelectedRows Operator
GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
)DOC");
}
};
class GetTensorFromSelectedRowsOpVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const final {
auto out_var_name = op_desc.Output("Out").front();
auto in_var_name = op_desc.Input("X").front();
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
out_var.SetDataType(in_var.GetDataType());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(get_tensor_from_selected_rows,
ops::GetTensorFromSelectedRowsOp,
ops::GetTensorFromSelectedRowsOpProtoMaker,
ops::GetTensorFromSelectedRowsOpVarTypeInference);
REGISTER_OP_CPU_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
ops::GetTensorFromSelectedRowsKernel, double,
ops::GetTensorFromSelectedRowsKernel, int,
ops::GetTensorFromSelectedRowsKernel, int64_t,
ops::GetTensorFromSelectedRowsKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float,
ops::GetTensorFromSelectedRowsKernel, double,
ops::GetTensorFromSelectedRowsKernel, int,
ops::GetTensorFromSelectedRowsKernel, int64_t,
ops::GetTensorFromSelectedRowsKernel);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h"
namespace paddle {
namespace operators {
class MergeSelectedRowsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MergeSelectedRowsOp should not be null.");
ctx->ShareDim("X", /*->*/ "Out");
}
};
class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input type is SelectedRows, and the selected rows may be "
"duplicated.");
AddOutput("Out",
"The output type is SelectedRows, and the selected rows are not "
"duplicated.");
AddComment(
R"DOC(
MergeSelectedRows Operator.
MergeSelectedRows is used to merge the duplicated rows of the input.
)DOC");
}
};
class MergeSelectedRowsOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(merge_selected_rows, ops::MergeSelectedRowsOp,
ops::MergeSelectedRowsOpMaker,
ops::MergeSelectedRowsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
merge_selected_rows,
ops::MergeSelectedRowsKernel<plat::CPUDeviceContext, float>,
ops::MergeSelectedRowsKernel<plat::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
merge_selected_rows,
ops::MergeSelectedRowsKernel<plat::CUDADeviceContext, float>,
ops::MergeSelectedRowsKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MergeSelectedRowsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out");
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -143,7 +143,7 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
auto *kernel =
reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
tracer->AddKernelRecords(kernel->start, kernel->end,
tracer->AddKernelRecords(kernel->name, kernel->start, kernel->end,
kernel->deviceId, kernel->streamId,
kernel->correlationId);
break;
......@@ -224,8 +224,9 @@ class DeviceTracerImpl : public DeviceTracer {
stream_id, correlation_id, bytes});
}
void AddKernelRecords(uint64_t start, uint64_t end, int64_t device_id,
int64_t stream_id, uint32_t correlation_id) {
void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
int64_t device_id, int64_t stream_id,
uint32_t correlation_id) {
// 0 means timestamp information could not be collected for the kernel.
if (start == 0 || end == 0) {
VLOG(3) << correlation_id << " cannot be traced";
......@@ -233,7 +234,7 @@ class DeviceTracerImpl : public DeviceTracer {
}
std::lock_guard<std::mutex> l(trace_mu_);
kernel_records_.push_back(
KernelRecord{start, end, device_id, stream_id, correlation_id});
KernelRecord{name, start, end, device_id, stream_id, correlation_id});
}
bool IsEnabled() {
......@@ -276,13 +277,13 @@ class DeviceTracerImpl : public DeviceTracer {
profile_pb.set_start_ns(start_ns_);
profile_pb.set_end_ns(end_ns_);
for (const KernelRecord &r : kernel_records_) {
if (correlations_.find(r.correlation_id) == correlations_.end()) {
fprintf(stderr, "cannot relate a kernel activity\n");
continue;
}
auto *event = profile_pb.add_events();
event->set_type(proto::Event::GPUKernel);
event->set_name(correlations_.at(r.correlation_id));
if (correlations_.find(r.correlation_id) != correlations_.end()) {
event->set_name(correlations_.at(r.correlation_id));
} else {
event->set_name(r.name);
}
event->set_start_ns(r.start_ns);
event->set_end_ns(r.end_ns);
event->set_sub_device_id(r.stream_id);
......
......@@ -39,6 +39,7 @@ inline uint64_t PosixInNsec() {
class DeviceTracer {
public:
struct KernelRecord {
std::string name;
uint64_t start_ns;
uint64_t end_ns;
int64_t device_id;
......@@ -84,8 +85,9 @@ class DeviceTracer {
// Add a cuda kernel stats. `correlation_id` will be mapped to annotation
// added before for human readability.
virtual void AddKernelRecords(uint64_t start, uint64_t end, int64_t device_id,
int64_t stream_id, uint32_t correlation_id) = 0;
virtual void AddKernelRecords(std::string name, uint64_t start, uint64_t end,
int64_t device_id, int64_t stream_id,
uint32_t correlation_id) = 0;
// Generate a proto after done (Disabled).
virtual proto::Profile GenProfile(const std::string& profile_path) = 0;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/split.h"
#ifndef _WIN32
constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
......@@ -45,6 +46,15 @@ DEFINE_bool(
"input and output must be half precision) and recurrent neural networks "
"(RNNs).");
DEFINE_string(selected_gpus, "",
"A list of device ids separated by comma, like: 0,1,2,3. "
"This option is useful when doing multi process training and "
"each process have only one device (GPU). If you want to use "
"all visible devices, set this to empty string. NOTE: the "
"reason of doing this is that we want to use P2P communication"
"between GPU devices, use CUDA_VISIBLE_DEVICES can only use"
"share-memory only.");
namespace paddle {
namespace platform {
......@@ -121,6 +131,24 @@ int GetCurrentDeviceId() {
return device_id;
}
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices() {
// use user specified GPUs in single-node multi-process mode.
std::vector<int> devices;
if (!FLAGS_selected_gpus.empty()) {
auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
for (auto id : devices_str) {
devices.push_back(atoi(id.c_str()));
}
} else {
int count = GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
devices.push_back(i);
}
}
return devices;
}
void SetDeviceId(int id) {
// TODO(qijun): find a better way to cache the cuda device count
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <cuda_runtime.h>
#include <stddef.h>
#include <string>
#include <vector>
namespace paddle {
namespace platform {
......@@ -47,6 +48,9 @@ int GetCUDAMaxThreadsPerMultiProcessor(int i);
//! Get the current GPU device id in system.
int GetCurrentDeviceId();
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices();
//! Set the GPU device id for next execution.
void SetDeviceId(int device_id);
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/string/split.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
......@@ -82,10 +83,8 @@ void InitDevices(bool init_p2p) {
std::vector<int> devices;
#ifdef PADDLE_WITH_CUDA
try {
int count = platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
devices.push_back(i);
}
// use user specified GPUs in single-node multi-process mode.
devices = platform::GetSelectedDevices();
} catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
}
......@@ -95,20 +94,15 @@ void InitDevices(bool init_p2p) {
void InitDevices(bool init_p2p, const std::vector<int> devices) {
std::vector<platform::Place> places;
int count = 0;
#ifdef PADDLE_WITH_CUDA
try {
count = platform::GetCUDADeviceCount();
} catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
}
#endif
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i] >= count || devices[i] < 0) {
// In multi process multi gpu mode, we may have gpuid = 7
// but count = 1.
if (devices[i] < 0) {
LOG(WARNING) << "Invalid devices id.";
continue;
}
places.emplace_back(platform::CUDAPlace(devices[i]));
}
if (init_p2p) {
......
......@@ -97,7 +97,7 @@ struct NCCLContextMap {
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
if (places.size() <= 1) {
if (places.size() <= 1 && num_trainers == 1) {
return;
}
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
......@@ -111,12 +111,19 @@ struct NCCLContextMap {
{
int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad;
for (auto &gpu_id : order_) {
int rank = trainer_id * order_.size() + gpu_id;
VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
for (size_t i = 0; i < order_.size(); ++i) {
int gpu_id = order_[i];
int rank;
if (order_.size() > 1) {
rank = trainer_id * order_.size() + i;
} else {
rank = trainer_id;
}
VLOG(30) << "init nccl rank: " << rank << " nranks: " << nranks
<< "gpu id: " << gpu_id;
PADDLE_ENFORCE(cudaSetDevice(gpu_id));
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
comms.get() + gpu_id, nranks, *nccl_id, rank));
comms.get() + i, nranks, *nccl_id, rank));
}
}
}
......
......@@ -3,3 +3,4 @@ cc_library(pretty_log SRCS pretty_log.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
cc_test(to_string_test SRCS to_string_test.cc)
cc_test(split_test SRCS split_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sstream>
#include <string>
#include <vector>
namespace paddle {
namespace string {
static inline std::vector<std::string> Split(std::string const& original,
char separator) {
std::vector<std::string> results;
std::string token;
std::istringstream is(original);
while (std::getline(is, token, separator)) {
if (!token.empty()) {
results.push_back(token);
}
}
return results;
}
} // namespace string
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/string/split.h"
#include <string>
#include "gtest/gtest.h"
TEST(StringSplit, StringSplit) {
std::string to_split = "0,1,2,3,4,5";
int i = 0;
for (auto s : paddle::string::Split(to_split, ',')) {
EXPECT_EQ(atoi(s.c_str()), i);
i++;
}
}
......@@ -147,7 +147,7 @@ def __bootstrap__():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search'
'cudnn_exhaustive_search', 'selected_gpus'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
......
......@@ -271,7 +271,12 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
"All parameters' 'clip_norm' of a same group should be the same"
)
square = grad * grad
merge_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(grad)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
local_norm_var = layers.reduce_sum(input=square)
context[self.group_name].append(local_norm_var)
......@@ -292,6 +297,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
new_grad = layers.elementwise_mul(
x=grad, y=self.context[group_scale_name])
return param, new_grad
......
......@@ -169,6 +169,8 @@ __all__ = [
'log_loss',
'add_position_encoding',
'bilinear_tensor_product',
'merge_selected_rows',
'get_tensor_from_selected_rows',
'lstm',
]
......@@ -8382,6 +8384,29 @@ def mean(x, name=None):
return out
@templatedoc()
def merge_selected_rows(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(basestring|None): Name of the output.
Returns:
out(${out_type}): ${out_comment}
"""
helper = LayerHelper("merge_selected_rows", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="merge_selected_rows",
inputs={"X": x},
attrs={},
outputs={"Out": out})
return out
@templatedoc()
def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
"""
......@@ -9034,3 +9059,26 @@ def bilinear_tensor_product(x,
# add activation
return helper.append_activation(out)
@templatedoc()
def get_tensor_from_selected_rows(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(basestring|None): Name of the output.
Returns:
out(${out_type}): ${out_comment}
"""
helper = LayerHelper('get_tensor_from_selected_rows', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='get_tensor_from_selected_rows',
inputs={'X': x},
outputs={'Out': out},
attrs={})
return out
......@@ -95,7 +95,14 @@ class ParallelExecutor(object):
self._places = []
self._act_places = []
if use_cuda:
for i in six.moves.range(core.get_cuda_device_count()):
gpus = []
gpus_env = os.getenv("FLAGS_selected_gpus")
if gpus_env:
gpus = [int(s) for s in gpus_env.split(",")]
else:
for i in six.moves.range(core.get_cuda_device_count()):
gpus.append(i)
for i in gpus:
p = core.Place()
self._act_places.append(core.CUDAPlace(i))
p.set_place(self._act_places[-1])
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
BATCH_SIZE = 128
CLIP = 1
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
prog_clip = prog.clone()
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
p_g = fluid.backward.append_backward(loss=avg_cost)
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
with fluid.program_guard(main_program=prog_clip):
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
grad_list = [elem[1] for elem in p_g]
grad_clip_list = [elem[1] for elem in p_g_clip]
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(fluid.default_startup_program())
count = 0
for data in train_reader():
count += 1
if count > 5:
break
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
global_norm = 0
for v in out[1:]:
global_norm += np.sum(np.power(v, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in out_clip[1:]:
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
if not np.isclose(
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3):
exit(1)
exit(0)
......@@ -43,13 +43,14 @@ if(APPLE)
list(REMOVE_ITEM TEST_OPS test_desc_clone)
list(REMOVE_ITEM TEST_OPS test_program_code)
endif(NOT WITH_DISTRIBUTE)
message(WARNING "These tests has been disabled in OSX before being fixed: \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext")
message(WARNING "These tests has been disabled in OSX before being fixed: \n test_gradient_clip \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext")
# this op is not support on mac
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
# TODO: add the unitest back when it fixed
list(REMOVE_ITEM TEST_OPS test_detection_map_op)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
list(REMOVE_ITEM TEST_OPS test_fuse_elewise_add_act_pass)
list(REMOVE_ITEM TEST_OPS test_gradient_clip)
endif()
if(NOT WITH_MKLML)
# this op is not support on openblas
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.core as core
import numpy as np
from paddle.fluid.op import Operator
class TestGetTensorFromSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
height = 20
row_numel = 2
np_array = np.ones((len(x_rows), row_numel)).astype("float32")
np_array[1, :] = 2.0
np_array[2, :] = 3.0
np_array[3, :] = 4.0
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(x_rows)
x.set_height(height)
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
# initialize input variable Out
out = scope.var("Out").get_tensor()
op = Operator("get_tensor_from_selected_rows", X="X", Out="Out")
op.run(scope, place)
out_array = np.array(out)
self.assertEqual((5, 2), out_array.shape)
assert (out_array == np_array).all()
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
BATCH_SIZE = 128
CLIP = 1
def bow_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(
input=data, is_sparse=True, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class TestGradientClip(unittest.TestCase):
def setUp(self):
self.word_dict = paddle.dataset.imdb.word_dict()
self.BATCH_SIZE = 2
self.train_data = paddle.batch(
paddle.dataset.imdb.train(self.word_dict),
batch_size=self.BATCH_SIZE)
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def check_operators(self, place):
prog = fluid.framework.Program()
startup_program = fluid.framework.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
prog_clip = prog.clone()
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
p_g = fluid.backward.append_backward(loss=avg_cost)
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
with fluid.program_guard(main_program=prog_clip):
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
grad_list = [elem[1] for elem in p_g]
grad_clip_list = [elem[1] for elem in p_g_clip]
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(startup_program)
count = 0
for data in train_reader():
count += 1
if count > 5:
break
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
global_norm = 0
for v in out[1:]:
global_norm += np.sum(np.power(v, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in out_clip[1:]:
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
assert np.isclose(
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3)
def check_sparse_gradient_clip(self, place):
prog = fluid.framework.Program()
startup_program = fluid.framework.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost = bow_net(data, label, len(self.word_dict))
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(cost)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
exe.run(startup_program)
data = next(self.train_data())
val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0]
self.assertEqual((1, ), val.shape)
print(val)
self.assertFalse(np.isnan(val))
def test_operators(self):
self.check_operators(core.CPUPlace())
def test_sparse_gradient_clip(self):
for place in self.get_places():
self.check_sparse_gradient_clip(place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.core as core
import numpy as np
from paddle.fluid.op import Operator
class TestMergeSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
out_rows = [0, 4, 5, 20]
height = 20
row_numel = 2
np_array = np.ones((len(x_rows), row_numel)).astype("float32")
np_array[1, :] = 2.0
np_array[2, :] = 3.0
np_array[3, :] = 4.0
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(x_rows)
x.set_height(height)
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
# initialize input variable Out
out = scope.var("Out").get_selected_rows()
op = Operator("merge_selected_rows", X="X", Out="Out")
op.run(scope, place)
self.assertEqual(out.rows(), out_rows)
self.assertEqual(out.height(), height)
out_array = np.array(out.get_tensor())
self.assertEqual((4, 2), out_array.shape)
assert (out_array[0, :] == 1.0).all()
assert (out_array[1, :] == 4.0).all()
assert (out_array[2, :] == 5.0).all()
assert (out_array[3, :] == 1.0).all()
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册