提交 0a0f1948 编写于 作者: Z zchen0211

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

......@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
dup_output_ops[out].emplace_back(local_op_id);
return false;
});
net->AddOp(std::move(bwd));
net->AppendOp(std::move(bwd));
}
// Get unique ID for this method.
auto uid = uniq_id++;
......@@ -163,8 +163,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}},
{{"Dst", {grad_input}}}, {}));
net->AppendOp(OpRegistry::CreateOp("fill_zeros_like",
{{"Src", {prefix}}},
{{"Dst", {grad_input}}}, {}));
}
return false;
});
......@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(std::move(grad_op));
net->AppendOp(std::move(grad_op));
}
net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp();
......
......@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
AppendOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
auto input_b = Inputs("b");
std::string before_act = "mul_result";
if (input_b.size() != 0) {
AddOp(OpRegistry::CreateOp(
AppendOp(OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
{{"Out", {Output("add_result")}}}, {}));
before_act = "add_result";
......@@ -92,8 +92,8 @@ class FcOp : public operators::NetOp {
}
}
AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
{{"Out", {Output("Out")}}}, {}));
AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
{{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
......@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST(Backward, net_input_of_network_not_need_grad) {
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp(
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_tmp_0"}},
{"add_result", {"add_tmp_0"}},
{"Out", {"hidden0"}}},
{}));
net.AddOp(f::OpRegistry::CreateOp(
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_tmp_1"}},
{"add_result", {"add_tmp_1"}},
......@@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) {
TEST(Backward, net_shared_weight) {
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
{{"Out", {"out"}}}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
{{"Out", {"FinalOut"}}}, {}));
net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
{{"Out", {"out"}}}, {}));
net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
{{"Out", {"FinalOut"}}}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
......@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp(
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_out1"}},
{"add_result", {"add_out1"}},
{"Out", {"out1"}}},
{}));
net.AddOp(f::OpRegistry::CreateOp(
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_out2"}},
{"add_result", {"tmp_out2"}},
{"Out", {"out2"}}},
{}));
net.AddOp(f::OpRegistry::CreateOp(
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
{{"mul_result", {"mul_out3"}},
{"add_result", {"tmp_out3"}},
......
......@@ -31,7 +31,7 @@ limitations under the License. */
namespace py = pybind11;
USE_OP(add_two);
USE_CPU_ONLY_OP(onehot_cross_entropy);
USE_OP(onehot_cross_entropy);
USE_OP(sgd);
USE_OP(mul);
USE_OP(mean);
......@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle.
retv->SetType("plain_net");
return retv;
})
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("append_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AppendOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
......
......@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>(
para->getMat(PARAMETER_VALUE).get());
para->clearGradient();
mat->clearIndices();
if (mat) mat->clearIndices();
}
}
}
......
......@@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
// TODO(superjom) add enforce here after helper functions ready
X_grad->Resize(X->dims());
dX->Resize(X->dims());
}
};
......@@ -70,9 +69,7 @@ namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<float>);
......@@ -12,10 +12,122 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__host__ __device__ T clipping_log(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
T v = log(x);
if (v == INFINITY) {
return kApproInf;
}
if (v == -INFINITY) {
return -kApproInf;
}
return v;
}
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -clipping_log(X[i * D + label[i]]);
}
}
// TODO(qingqing): make zero setting an common function.
template <typename T>
__global__ void zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
}
}
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int* label, const int N,
const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
int idx = i * D + label[i];
dX[idx] = -dY[i] / X[idx];
}
}
template <typename T>
class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace());
T* Ydata = Y->data<T>();
int N = X->dims()[0];
int D = X->dims()[1];
int block = 512;
int grid = (N + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
}
};
template <typename T>
class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("label");
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
auto* dYdata = dY->template data<T>();
auto* Xdata = X->template data<T>();
auto* label_data = label->data<int>();
int N = X->dims()[0];
int D = X->dims()[1];
int block = 512;
int grid = (N * D + block - 1) / block;
zero<T><<<grid, block>>>(dXdata, N * D);
grid = (N + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
label_data, N, D);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
......@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
T tolerable_value(T x) {
inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");
......@@ -39,10 +39,13 @@ T tolerable_value(T x) {
return x;
}
template <typename Place, typename T>
template <typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
......@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
}
};
template <typename Place, typename T>
template <typename T>
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
......@@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
const int batch_size = X->dims()[0];
const int class_num = X->dims()[1];
// TODO(qingqing): make zero setting an common function.
memset(dXdata, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -19,25 +16,25 @@ namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
class CPUGaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int seed = context.op_.GetAttr<int>("seed");
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
engine.seed(seed);
std::normal_distribution<T> dist(mean, std);
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
};
......@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
auto* tensor = context.Output<framework::Tensor>("Out");
auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set.");
......@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
......@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
return dist(rng);
}
};
template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
int seed = context.op_.GetAttr<int>("seed");
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
curandGenerator_t g;
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
platform::dynload::curandGenerateNormal(
g, data, framework::product(tensor->dims()), mean, std);
T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
T std = static_cast<T>(context.op_.GetAttr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
REGISTER_OP_GPU_KERNEL(gaussian_random,
paddle::operators::GPUGaussianRandomKernel<float>);
......@@ -13,7 +13,6 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......
......@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase {
return true;
}
void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); }
void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); }
/**
* @brief Add an operator by ptr
*/
void AddOp(std::unique_ptr<framework::OperatorBase> op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
void AppendOp(std::unique_ptr<framework::OperatorBase> op) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot AppendOp when this network is sealed");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(std::move(op));
}
......
......@@ -38,10 +38,10 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
net->AddOp(std::unique_ptr<TestOp>(
net->AppendOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {})));
net->AddOp(std::unique_ptr<TestOp>(
net->AppendOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"Out", {"z"}}}, {})));
......@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) {
auto op1 = std::unique_ptr<framework::NOP>(
new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net.AddOp(*op1);
net.AppendOp(*op1);
net.InsertOp(0, *op1);
ASSERT_EQ(2UL, net.ops_.size());
net.InsertOp(2, std::move(op1));
......@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) {
TEST(NetOp, Clone) {
NetOp net;
net.AddOp(
net.AppendOp(
std::unique_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
net.AddOp(std::unique_ptr<framework::NOP>(
net.AppendOp(std::unique_ptr<framework::NOP>(
new framework::NOP{"empty2", {}, {}, {}}));
net.CompleteAddOp(true);
auto new_net_op = net.Clone();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
}
......@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -65,7 +65,10 @@ public:
size_t getSize() const { return config_.size(); }
bool isFullSize() const {
return this->getSize() == bufs_[PARAMETER_VALUE]->getSize();
if (bufs_[PARAMETER_VALUE]) {
return this->getSize() == bufs_[PARAMETER_VALUE]->getSize();
}
return false;
}
inline bool useGpu() const { return useGpu_; }
......
......@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
......@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
cudaStream_t CUDADeviceContext::stream() { return stream_; }
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU
} // namespace platform
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
......@@ -40,7 +39,7 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace);
explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const;
......@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(GPUPlace);
explicit CUDADeviceContext(GPUPlace place);
virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */
......@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle();
/*! \brief Return curand handle in the device context. */
curandGenerator_t curand_generator();
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream();
// clang-format on
......@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
private:
GPUPlace place_;
private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
private:
uint64_t seed_;
// clang-format off
cudaStream_t stream_{nullptr};
cudnnHandle_t cudnn_handle_{nullptr};
cublasHandle_t cublas_handle_{nullptr};
curandGenerator_t curand_generator_{nullptr};
// clang-format on
};
......
......@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context;
}
......
......@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
LOG(INFO) << "parallel_thread_num dosent need to set";
}
syncThreadPool_.reset(new SyncThreadPool(threadNum_));
startThreads();
}
......@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
request.set_cost(cost);
request.set_batch_status(batchStatus);
CHECK_EQ(request.blocks_size(), 0);
VLOG(10) << "request: trainer_id: " << request.trainer_id()
<< " update_mode" << request.update_mode()
<< " send_back_parameter: " << request.send_back_parameter()
<< " send_back_parameter_type: "
<< request.send_back_parameter_type()
<< " num_samples: " << request.num_samples()
<< " cost: " << request.cost()
<< " batch_status: " << request.batch_status();
}
for (const auto& segments : parameterSegments) {
const auto it = parameterMap_.find(segments.id);
......@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
CHECK(sendMat != nullptr) << "sendMat is nullptr";
syncThreadPool_->exec([&](int tid, size_t numThreads) {
std::lock_guard<std::mutex> guard(sparseAutoGrowthMutex_);
const auto& localIndices = prefetchMat->getLocalIndices();
/// num of sparse rows
size_t nLocalBlocks = localIndices.size();
uint64_t beginDim = 0;
uint64_t endDim = 0;
// FIXME(typhoonzero): let it resize first
prefetchMat->getLocalRow(nLocalBlocks + 1);
sendMat->getLocalRow(nLocalBlocks + 1);
for (size_t row = 0; row < nLocalBlocks; ++row) {
int64_t blockId = localIndices[row]; // local row -> sparse row
int serverId = std::abs((blockId + nameHash) % serviceNum_);
......@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
block->set_begin_pos(row * blockSize);
/// block len
block->set_block_size(endDim - beginDim);
if (sendingPara) {
sendJob->parallelInputIovs[serverId].push_back(
{sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize});
......
......@@ -583,6 +583,7 @@ protected:
#ifndef PADDLE_DISABLE_TIMER
uint64_t forwardbackwordTime_;
#endif
std::mutex sparseAutoGrowthMutex_;
/// map id to parameter used for decoding protobuf data
std::unordered_map<size_t, ParameterPtr> parameterMap_;
......
......@@ -23,7 +23,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
......
......@@ -64,7 +64,8 @@ class OpTestMeta(type):
actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
numpy.allclose(actual, expect),
numpy.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
obj.test_all = test_all
......
......@@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
# TODO this unit test is not passed
self.type = "onehot_cross_entropy"
batch_size = 100
batch_size = 30
class_num = 10
X = numpy.random.random((batch_size, class_num)).astype("float32")
label = 5 * numpy.ones(batch_size).astype("int32")
......@@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase):
class CrossEntropyGradOpTest(GradientChecker):
def test_softmax_grad(self):
def test_check_grad(self):
op = create_op("onehot_cross_entropy")
batch_size = 100
batch_size = 30
class_num = 10
inputs = {
"X": numpy.random.uniform(
......
......@@ -6,8 +6,8 @@ import unittest
def fc(X, W, Y):
ret_v = core.Net.create()
ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.complete_add_op(True)
return ret_v
......@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
op1 = Operator("add_two", X="X", Y="Y", Out="Out")
net.add_op(op1)
net.append_op(op1)
net2 = core.Net.create()
net2.add_op(fc(X="X", W="w", Y="fc.out"))
net2.append_op(fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
net.append_op(net2)
net.complete_add_op(True)
expected = '''
......
......@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase):
sig_op = Operator("sigmoid", X="sum", Y="h@alias")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.add_op(op)
stepnet.append_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册