提交 1795e576 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into lookup_table

...@@ -362,6 +362,11 @@ trans ...@@ -362,6 +362,11 @@ trans
.. autoclass:: paddle.v2.layer.trans .. autoclass:: paddle.v2.layer.trans
:noindex: :noindex:
scale_shift
-----------
.. autoclass:: paddle.v2.layer.scale_shift
:noindex:
Sampling Layers Sampling Layers
=============== ===============
......
...@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -110,7 +110,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
dup_output_ops[out].emplace_back(local_op_id); dup_output_ops[out].emplace_back(local_op_id);
return false; return false;
}); });
net->AddOp(std::move(bwd)); net->AppendOp(std::move(bwd));
} }
// Get unique ID for this method. // Get unique ID for this method.
auto uid = uniq_id++; auto uid = uniq_id++;
...@@ -163,8 +163,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -163,8 +163,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient. // zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}}, net->AppendOp(OpRegistry::CreateOp("fill_zeros_like",
{{"Dst", {grad_input}}}, {})); {{"Src", {prefix}}},
{{"Dst", {grad_input}}}, {}));
} }
return false; return false;
}); });
...@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -195,7 +196,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
if (net->ops_.empty()) { // Current no aux op is added to network if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op; return grad_op;
} }
net->AddOp(std::move(grad_op)); net->AppendOp(std::move(grad_op));
} }
net->SetType("@GENERATED_BACKWARD@"); net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp(); net->CompleteAddOp();
......
...@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp { ...@@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
FcOp(const std::string &type, const VarNameMap &inputs, FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs) const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) { : NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul", AppendOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}}, {{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {})); {{"Out", {Output("mul_result")}}}, {}));
auto input_b = Inputs("b"); auto input_b = Inputs("b");
std::string before_act = "mul_result"; std::string before_act = "mul_result";
if (input_b.size() != 0) { if (input_b.size() != 0) {
AddOp(OpRegistry::CreateOp( AppendOp(OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
{{"Out", {Output("add_result")}}}, {})); {{"Out", {Output("add_result")}}}, {}));
before_act = "add_result"; before_act = "add_result";
...@@ -92,8 +92,8 @@ class FcOp : public operators::NetOp { ...@@ -92,8 +92,8 @@ class FcOp : public operators::NetOp {
} }
} }
AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
{{"Out", {Output("Out")}}}, {})); {{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
...@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) { ...@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_input_of_network_not_need_grad) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp( net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}}, "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_tmp_0"}}, {{"mul_result", {"mul_tmp_0"}},
{"add_result", {"add_tmp_0"}}, {"add_result", {"add_tmp_0"}},
{"Out", {"hidden0"}}}, {"Out", {"hidden0"}}},
{})); {}));
net.AddOp(f::OpRegistry::CreateOp( net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}}, "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_tmp_1"}}, {{"mul_result", {"mul_tmp_1"}},
{"add_result", {"add_tmp_1"}}, {"add_result", {"add_tmp_1"}},
...@@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) { ...@@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) {
TEST(Backward, net_shared_weight) { TEST(Backward, net_shared_weight) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
{{"Out", {"out"}}}, {})); {{"Out", {"out"}}}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
{{"Out", {"FinalOut"}}}, {})); {{"Out", {"FinalOut"}}}, {}));
net.CompleteAddOp(); net.CompleteAddOp();
auto bwd = f::Backward(net, {}); auto bwd = f::Backward(net, {});
...@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) { ...@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST(Backward, linear_net_intermediate_variable_has_no_grad) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp( net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}}, "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_out1"}}, {{"mul_result", {"mul_out1"}},
{"add_result", {"add_out1"}}, {"add_result", {"add_out1"}},
{"Out", {"out1"}}}, {"Out", {"out1"}}},
{})); {}));
net.AddOp(f::OpRegistry::CreateOp( net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}}, "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_out2"}}, {{"mul_result", {"mul_out2"}},
{"add_result", {"tmp_out2"}}, {"add_result", {"tmp_out2"}},
{"Out", {"out2"}}}, {"Out", {"out2"}}},
{})); {}));
net.AddOp(f::OpRegistry::CreateOp( net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}}, "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
{{"mul_result", {"mul_out3"}}, {{"mul_result", {"mul_out3"}},
{"add_result", {"tmp_out3"}}, {"add_result", {"tmp_out3"}},
......
...@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -223,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle.
retv->SetType("plain_net"); retv->SetType("plain_net");
return retv; return retv;
}) })
.def("add_op", [](operators::NetOp &self, .def("append_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); }) const OperatorBase &op) { self.AppendOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) { .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp(); self->CompleteAddOp();
......
...@@ -84,7 +84,7 @@ struct BlasGemm<DEVICE_TYPE_GPU, T> { ...@@ -84,7 +84,7 @@ struct BlasGemm<DEVICE_TYPE_GPU, T> {
} }
}; };
template class BlasGemm<DEVICE_TYPE_CPU, real>; template struct BlasGemm<DEVICE_TYPE_CPU, real>;
template class BlasGemm<DEVICE_TYPE_GPU, real>; template struct BlasGemm<DEVICE_TYPE_GPU, real>;
} // namespace paddle } // namespace paddle
...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) { ...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>( auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>(
para->getMat(PARAMETER_VALUE).get()); para->getMat(PARAMETER_VALUE).get());
para->clearGradient(); para->clearGradient();
mat->clearIndices(); if (mat) mat->clearIndices();
} }
} }
} }
......
...@@ -184,7 +184,7 @@ public: ...@@ -184,7 +184,7 @@ public:
} }
void backward(const UpdateCallback& callback) override { void backward(const UpdateCallback& callback) override {
if (biases_) { if (biases_ && biases_->getWGrad()) {
backwardActivation(); backwardActivation();
biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
/**
* A layer applies a linear transformation to each element in each row of
* the input matrix. For each element, the layer first re-scale it and then
* adds a bias to it.
*
* \f[
* y = wx + b
* \f]
*
* Here, w is the scale and b is the bias. Both w and b are trainable scalars.
*
*/
class ScaleShiftLayer : public Layer {
protected:
std::unique_ptr<Weight> scale_;
std::unique_ptr<Weight> offset_;
public:
explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(scale_shift, ScaleShiftLayer);
bool ScaleShiftLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1U);
scale_.reset(new Weight(1, 1, parameters_[0]));
if (biasParameter_.get() != NULL) {
offset_ = std::unique_ptr<Weight>(new Weight(1, 1, biasParameter_));
}
return true;
}
void ScaleShiftLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
resetOutput(inV->getHeight(), inV->getWidth());
MatrixPtr outV = getOutputValue();
real scaleValue = scale_->getW()->getElement(0, 0);
outV->mulScalar(*inV, scaleValue);
if (offset_) {
real offsetValue = offset_->getW()->getElement(0, 0);
outV->add(offsetValue);
}
}
void ScaleShiftLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV = getInputValue(0);
MatrixPtr inG = getInputGrad(0);
MatrixPtr outV = getOutputValue();
MatrixPtr outG = getOutputGrad();
/* Calculate the parameter gradient for the current layer */
if (scale_->getWGrad()) {
MatrixPtr rowSumMtx;
Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij}
rowSumMtx->sumOfProducts(
/* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji}
scale_->getWGrad()->sumCols(
/* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.);
scale_->getParameterPtr()->incUpdate(callback);
}
if (offset_ && offset_->getWGrad()) {
MatrixPtr rowSumMtx;
Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_);
rowSumMtx->sumRows(*outG, 1., 0.);
offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.);
offset_->getParameterPtr()->incUpdate(callback);
}
/* Calculate the input layers error */
if (inG) {
real scaleValue = scale_->getW()->getElement(0, 0);
inG->add(*outG, scaleValue);
}
}
} // namespace paddle
...@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) { ...@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) {
} }
} }
TEST(Layer, ScaleShiftLayer) {
const size_t batchSize = 16;
const size_t size = 32;
TestConfig config;
config.layerConfig.set_type("scale_shift");
config.layerConfig.set_size(size);
config.biasSize = 1;
config.inputDefs.push_back(
{INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
}
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
...@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { ...@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) {
bool useGpu = FLAGS_use_gpu; bool useGpu = FLAGS_use_gpu;
double eps = FLAGS_checkgrad_eps; double eps = FLAGS_checkgrad_eps;
FLAGS_use_gpu = true; FLAGS_use_gpu = true;
FLAGS_checkgrad_eps = 1e-2; // Sometimes, this unit test will fail with 1e-2
FLAGS_checkgrad_eps = 4e-2;
compareNetwork(config_file_a, config_file_b); compareNetwork(config_file_a, config_file_b);
FLAGS_use_gpu = useGpu; FLAGS_use_gpu = useGpu;
FLAGS_checkgrad_eps = eps; FLAGS_checkgrad_eps = eps;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -19,25 +16,25 @@ namespace paddle { ...@@ -19,25 +16,25 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean"); float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std"); float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int. unsigned int seed =
// And we need a global random seed configuration. static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
int seed = context.op_.GetAttr<int>("seed"); std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
std::mt19937 g(seed); engine.seed(seed);
std::normal_distribution<T> distribution(mean, std); std::normal_distribution<T> dist(mean, std);
ssize_t size = framework::product(tensor->dims()); ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) { for (ssize_t i = 0; i < size; ++i) {
data[i] = distribution(g); data[i] = dist(engine);
} }
} }
}; };
...@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>("Out");
auto dims = GetAttr<std::vector<int>>("dims"); auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL, PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
...@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator. ...@@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC"); )DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f); AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f); AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
AddAttr<int>("seed", AddAttr<int>("seed",
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed") "0 means use system wide seed")
...@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator. ...@@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker); ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory> #include <thrust/device_ptr.h>
#include <random> #include <thrust/iterator/counting_iterator.h>
#include "paddle/platform/dynload/curand.h" #include <thrust/random.h>
#include "paddle/platform/gpu_info.h" #include <thrust/transform.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomKernel : public framework::OpKernel { struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
return dist(rng);
}
};
template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean"); auto* tensor = context.Output<framework::Tensor>("Out");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
int seed = context.op_.GetAttr<int>("seed"); static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
curandGenerator_t g; T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( T std = static_cast<T>(context.op_.GetAttr<float>("std"));
&g, CURAND_RNG_PSEUDO_DEFAULT)); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
PADDLE_ENFORCE( ssize_t N = framework::product(tensor->dims());
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); thrust::transform(index_sequence_begin, index_sequence_begin + N,
platform::dynload::curandGenerateNormal( thrust::device_ptr<T>(data),
g, data, framework::product(tensor->dims()), mean, std); GaussianGenerator<T>(mean, std, seed));
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(gaussian_random,
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); paddle::operators::GPUGaussianRandomKernel<float>);
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase { ...@@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase {
return true; return true;
} }
void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); }
/** /**
* @brief Add an operator by ptr * @brief Add an operator by ptr
*/ */
void AddOp(std::unique_ptr<framework::OperatorBase> op) { void AppendOp(std::unique_ptr<framework::OperatorBase> op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_,
"Cannot AppendOp when this network is sealed");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(std::move(op)); ops_.push_back(std::move(op));
} }
......
...@@ -38,10 +38,10 @@ TEST(OpKernel, all) { ...@@ -38,10 +38,10 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>(); auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr); ASSERT_NE(net, nullptr);
net->AddOp(std::unique_ptr<TestOp>( net->AppendOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}))); {{"Out", {"y"}}}, {})));
net->AddOp(std::unique_ptr<TestOp>( net->AppendOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"Out", {"z"}}}, {}))); {{"Out", {"z"}}}, {})));
...@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) { ...@@ -61,7 +61,7 @@ TEST(NetOp, insert_op) {
auto op1 = std::unique_ptr<framework::NOP>( auto op1 = std::unique_ptr<framework::NOP>(
new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {})); {{"Out", {"y"}}}, {}));
net.AddOp(*op1); net.AppendOp(*op1);
net.InsertOp(0, *op1); net.InsertOp(0, *op1);
ASSERT_EQ(2UL, net.ops_.size()); ASSERT_EQ(2UL, net.ops_.size());
net.InsertOp(2, std::move(op1)); net.InsertOp(2, std::move(op1));
...@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) { ...@@ -70,9 +70,9 @@ TEST(NetOp, insert_op) {
TEST(NetOp, Clone) { TEST(NetOp, Clone) {
NetOp net; NetOp net;
net.AddOp( net.AppendOp(
std::unique_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}})); std::unique_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
net.AddOp(std::unique_ptr<framework::NOP>( net.AppendOp(std::unique_ptr<framework::NOP>(
new framework::NOP{"empty2", {}, {}, {}})); new framework::NOP{"empty2", {}, {}, {}}));
net.CompleteAddOp(true); net.CompleteAddOp(true);
auto new_net_op = net.Clone(); auto new_net_op = net.Clone();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
...@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel { ...@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel {
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add // colwise add
Eigen::array<int, 1> dims{{1}}; /* dimension to reduce */ Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims); EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")), static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max"))); static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
} }
} }
...@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output tensor of uniform random op"); AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(Uniform random operator. AddComment(R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator. Used to initialize tensor with uniform random generator.
)DOC"); )DOC");
AddAttr<std::vector<int>>("dims", "the dimension of random tensor"); AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
...@@ -84,4 +81,4 @@ Used to initialize tensor with uniform random generator. ...@@ -84,4 +81,4 @@ Used to initialize tensor with uniform random generator.
REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>); paddle::operators::CPUUniformRandomKernel<float>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -68,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel { ...@@ -68,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel {
} // namespace paddle } // namespace paddle
REGISTER_OP_GPU_KERNEL(uniform_random, REGISTER_OP_GPU_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>); paddle::operators::GPUUniformRandomKernel<float>);
\ No newline at end of file
...@@ -65,7 +65,10 @@ public: ...@@ -65,7 +65,10 @@ public:
size_t getSize() const { return config_.size(); } size_t getSize() const { return config_.size(); }
bool isFullSize() const { bool isFullSize() const {
return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); if (bufs_[PARAMETER_VALUE]) {
return this->getSize() == bufs_[PARAMETER_VALUE]->getSize();
}
return false;
} }
inline bool useGpu() const { return useGpu_; } inline bool useGpu() const { return useGpu_; }
......
...@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
} }
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { ...@@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
cudaStream_t CUDADeviceContext::stream() { return stream_; } cudaStream_t CUDADeviceContext::stream() { return stream_; }
curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
} // namespace platform } // namespace platform
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
...@@ -40,7 +39,7 @@ class DeviceContext { ...@@ -40,7 +39,7 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext(); CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace); explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {} virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const; Eigen::DefaultDevice* eigen_device() const;
...@@ -56,7 +55,7 @@ class EigenCudaStreamDevice; ...@@ -56,7 +55,7 @@ class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
explicit CUDADeviceContext(GPUPlace); explicit CUDADeviceContext(GPUPlace place);
virtual ~CUDADeviceContext(); virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */ /*! \brief Wait for all operations completion in the stream. */
...@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle(); cudnnHandle_t cudnn_handle();
/*! \brief Return curand handle in the device context. */
curandGenerator_t curand_generator();
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream(); cudaStream_t stream();
// clang-format on // clang-format on
...@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext { ...@@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext {
private: private:
GPUPlace place_; GPUPlace place_;
private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
private:
uint64_t seed_;
// clang-format off // clang-format off
cudaStream_t stream_{nullptr}; cudaStream_t stream_{nullptr};
cudnnHandle_t cudnn_handle_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr};
cublasHandle_t cublas_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr};
curandGenerator_t curand_generator_{nullptr};
// clang-format on // clang-format on
}; };
......
...@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ...@@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle(); cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle); ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle);
ASSERT_NE(nullptr, device_context->stream()); ASSERT_NE(nullptr, device_context->stream());
delete device_context; delete device_context;
} }
......
...@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { ...@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
LOG(INFO) << "parallel_thread_num dosent need to set"; LOG(INFO) << "parallel_thread_num dosent need to set";
} }
syncThreadPool_.reset(new SyncThreadPool(threadNum_)); syncThreadPool_.reset(new SyncThreadPool(threadNum_));
startThreads(); startThreads();
} }
...@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( ...@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
request.set_cost(cost); request.set_cost(cost);
request.set_batch_status(batchStatus); request.set_batch_status(batchStatus);
CHECK_EQ(request.blocks_size(), 0); CHECK_EQ(request.blocks_size(), 0);
VLOG(10) << "request: trainer_id: " << request.trainer_id()
<< " update_mode" << request.update_mode()
<< " send_back_parameter: " << request.send_back_parameter()
<< " send_back_parameter_type: "
<< request.send_back_parameter_type()
<< " num_samples: " << request.num_samples()
<< " cost: " << request.cost()
<< " batch_status: " << request.batch_status();
} }
for (const auto& segments : parameterSegments) { for (const auto& segments : parameterSegments) {
const auto it = parameterMap_.find(segments.id); const auto it = parameterMap_.find(segments.id);
...@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( ...@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
CHECK(sendMat != nullptr) << "sendMat is nullptr"; CHECK(sendMat != nullptr) << "sendMat is nullptr";
syncThreadPool_->exec([&](int tid, size_t numThreads) { syncThreadPool_->exec([&](int tid, size_t numThreads) {
std::lock_guard<std::mutex> guard(sparseAutoGrowthMutex_);
const auto& localIndices = prefetchMat->getLocalIndices(); const auto& localIndices = prefetchMat->getLocalIndices();
/// num of sparse rows /// num of sparse rows
size_t nLocalBlocks = localIndices.size(); size_t nLocalBlocks = localIndices.size();
uint64_t beginDim = 0; uint64_t beginDim = 0;
uint64_t endDim = 0; uint64_t endDim = 0;
// FIXME(typhoonzero): let it resize first
prefetchMat->getLocalRow(nLocalBlocks + 1);
sendMat->getLocalRow(nLocalBlocks + 1);
for (size_t row = 0; row < nLocalBlocks; ++row) { for (size_t row = 0; row < nLocalBlocks; ++row) {
int64_t blockId = localIndices[row]; // local row -> sparse row int64_t blockId = localIndices[row]; // local row -> sparse row
int serverId = std::abs((blockId + nameHash) % serviceNum_); int serverId = std::abs((blockId + nameHash) % serviceNum_);
...@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( ...@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
block->set_begin_pos(row * blockSize); block->set_begin_pos(row * blockSize);
/// block len /// block len
block->set_block_size(endDim - beginDim); block->set_block_size(endDim - beginDim);
if (sendingPara) { if (sendingPara) {
sendJob->parallelInputIovs[serverId].push_back( sendJob->parallelInputIovs[serverId].push_back(
{sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize});
......
...@@ -583,6 +583,7 @@ protected: ...@@ -583,6 +583,7 @@ protected:
#ifndef PADDLE_DISABLE_TIMER #ifndef PADDLE_DISABLE_TIMER
uint64_t forwardbackwordTime_; uint64_t forwardbackwordTime_;
#endif #endif
std::mutex sparseAutoGrowthMutex_;
/// map id to parameter used for decoding protobuf data /// map id to parameter used for decoding protobuf data
std::unordered_map<size_t, ParameterPtr> parameterMap_; std::unordered_map<size_t, ParameterPtr> parameterMap_;
......
...@@ -2232,6 +2232,20 @@ class ClipLayer(LayerBase): ...@@ -2232,6 +2232,20 @@ class ClipLayer(LayerBase):
self.config.inputs[0].clip_conf.max = max self.config.inputs[0].clip_conf.max = max
@config_layer('scale_shift')
class ScaleShiftLayer(LayerBase):
def __init__(self, name, inputs, bias=True, **xargs):
super(ScaleShiftLayer, self).__init__(
name, 'scale_shift', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'ScaleShiftLayer must have one and only one input.')
input_layer = self.get_input_layer(0)
self.set_layer_size(input_layer.size)
self.create_input_parameter(0, 1, [1, 1])
self.create_bias_parameter(bias, 1)
# key: cost type # key: cost type
# value: cost class # value: cost class
g_cost_map = {} g_cost_map = {}
......
...@@ -133,6 +133,7 @@ __all__ = [ ...@@ -133,6 +133,7 @@ __all__ = [
'clip_layer', 'clip_layer',
'slice_projection', 'slice_projection',
'kmax_sequence_score_layer', 'kmax_sequence_score_layer',
'scale_shift_layer',
] ]
...@@ -230,6 +231,7 @@ class LayerType(object): ...@@ -230,6 +231,7 @@ class LayerType(object):
CLIP_LAYER = 'clip' CLIP_LAYER = 'clip'
KMAX_SEQ_SCORE = 'kmax_seq_score' KMAX_SEQ_SCORE = 'kmax_seq_score'
SCALE_SHIFT_LAYER = 'scale_shift'
@staticmethod @staticmethod
def is_layer_type(type_name): def is_layer_type(type_name):
...@@ -6210,3 +6212,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): ...@@ -6210,3 +6212,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1):
return LayerOutput( return LayerOutput(
name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size)
@wrap_name_default("scale_shift")
@wrap_param_attr_default()
@wrap_bias_attr_default()
def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None):
"""
A layer applies a linear transformation to each element in each row of
the input matrix. For each element, the layer first re-scale it and then
adds a bias to it.
This layer is very like the SlopeInterceptLayer, except the scale and
bias are trainable.
.. math::
y = w * x + b
.. code-block:: python
scale_shift = scale_shift_layer(input=input_layer, bias_attr=False)
:param name: The Layer Name.
:type name: basestring
:param input: The input layer.
:type input: LayerOutput.
:param param_attr: The parameter attribute of scaling.
:type param_attr: ParameterAttribute
:param bias_attr: The parameter attribute of shifting.
:type bias_attr: ParameterAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer(
name=name,
type=LayerType.SCALE_SHIFT_LAYER,
inputs=Input(input.name, **param_attr.attr),
bias=ParamAttr.to_bias(bias_attr))
return LayerOutput(
name, LayerType.SCALE_SHIFT_LAYER, parents=[input], size=input.size)
...@@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops ...@@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers) test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer)
export whole_configs=(test_split_datasource) export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 100
active_type: ""
}
layers {
name: "__scale_shift_0__"
type: "scale_shift"
size: 100
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___scale_shift_0__.w0"
}
}
layers {
name: "__scale_shift_1__"
type: "scale_shift"
size: 100
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___scale_shift_1__.w0"
}
bias_parameter_name: "___scale_shift_1__.wbias"
}
parameters {
name: "___scale_shift_0__.w0"
size: 1
initial_mean: 0.0
initial_std: 1.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___scale_shift_1__.w0"
size: 1
initial_mean: 0.0
initial_std: 1.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___scale_shift_1__.wbias"
size: 1
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "__scale_shift_0__"
output_layer_names: "__scale_shift_1__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__scale_shift_0__"
layer_names: "__scale_shift_1__"
input_layer_names: "data"
output_layer_names: "__scale_shift_0__"
output_layer_names: "__scale_shift_1__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
data = data_layer(name='data', size=100)
scale = scale_shift_layer(input=data, bias_attr=False)
scale_shift = scale_shift_layer(input=data)
outputs(scale, scale_shift)
...@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) ...@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py) py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_sgd_op SRCS test_sgd_op.py)
......
...@@ -6,8 +6,8 @@ import unittest ...@@ -6,8 +6,8 @@ import unittest
def fc(X, W, Y): def fc(X, W, Y):
ret_v = core.Net.create() ret_v = core.Net.create()
ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation")) ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y)) ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.complete_add_op(True) ret_v.complete_add_op(True)
return ret_v return ret_v
...@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase): ...@@ -16,12 +16,12 @@ class TestNet(unittest.TestCase):
def test_net_all(self): def test_net_all(self):
net = core.Net.create() net = core.Net.create()
op1 = Operator("add_two", X="X", Y="Y", Out="Out") op1 = Operator("add_two", X="X", Y="Y", Out="Out")
net.add_op(op1) net.append_op(op1)
net2 = core.Net.create() net2 = core.Net.create()
net2.add_op(fc(X="X", W="w", Y="fc.out")) net2.append_op(fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True) net2.complete_add_op(True)
net.add_op(net2) net.append_op(net2)
net.complete_add_op(True) net.complete_add_op(True)
expected = ''' expected = '''
......
...@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase): ...@@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase):
sig_op = Operator("sigmoid", X="sum", Y="h@alias") sig_op = Operator("sigmoid", X="sum", Y="h@alias")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]: for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.add_op(op) stepnet.append_op(op)
stepnet.complete_add_op(True) stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet) self.rnnop.set_stepnet(stepnet)
......
...@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker): ...@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker):
def test_rowwise_add(self): def test_rowwise_add(self):
op = create_op("rowwise_add") op = create_op("rowwise_add")
inputs = { inputs = {
"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
"b": np.random.uniform(0.1, 1, [10]).astype("float32") "b": np.random.uniform(0.1, 1, [10]).astype("float32")
} }
self.check_grad(op, inputs, set(["X", "b"]), "Out") self.check_grad(op, inputs, set(["X", "b"]), "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册