提交 f087533c 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into executor_impl

......@@ -15,9 +15,9 @@ Please be aware that these Python classes need to maintain some construction-tim
### Program
A `ProgramDesc` describes a [DL program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md), which is composed of an array of `BlockDesc`s. A `BlockDesc` refers to its parent block by its index in the array. For example, operators in the step block of an RNN operator needs to be able to access variables in its ancessor blocks.
A `ProgramDesc` describes a [DL program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md), which is composed of an array of `BlockDesc`s. The `BlockDesc`s in a `ProgramDesc` can have a tree-like hierarchical structure. However, the `ProgramDesc` onlys stores a flattened array of `BlockDesc`s. A `BlockDesc` refers to its parent block by its index in the array. For example, operators in the step block of an RNN operator need to be able to access variables in its ancestor blocks.
Whenever we create a block, we need set its parent block to the current block, so the Python class `Program` needs to maintain a data member `current_block`.
Whenever we create a block, we need to set its parent block to the current block, hence the Python class `Program` needs to maintain a data member `current_block`.
```python
class Program(objects):
......@@ -81,13 +81,13 @@ class Block(objects):
self.ops.prepend(Operator(self, ...))
```
`create_parameter` is necessary because parameters are global variables, those defined in the global block, but can be created in some sub-blocks, e.g., an FC layer in the step block of an RNN operator.
`create_parameter` is necessary because parameters are global variables, defined in the global block, but can be created in some sub-blocks. For example, an FC layer in the step block of an RNN operator.
`prepand_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block.
`prepend_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block.
### Operator
The `Operator` class fills in the `OpDesc` message and calls the C++ function `InferShape` to infer output shape from input shape.
The `Operator` class fills in the `OpDesc` message and calls the C++ function `InferShape` to infer the output shapes from the input shapes.
```python
class Operator(object):
......@@ -105,7 +105,7 @@ class Operator(object):
return self.proto.type()
```
`Operator` creates the `OpDesc` message in C++ space, so could it call the `InferShape` function, which is in C++.
`Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
### Variable
......@@ -128,7 +128,7 @@ class Variable(object):
self.writer = None
```
Please be aware of `self.writer`, that tracks operator who creates the variable. It possible that there are more than one operators who write a variable, but in Python space, each writes to a variable is represented by a Variable class. This is guaranteed by the fact that **`core.NewVarDesc` must NOT create a new `VarDesc` message if its name already exists in the specified block**.
Please be aware of `self.writer`, that tracks operator who creates the variable. It possible that there are more than one operators who write a variable, but in Python space, each write to a variable is represented by a Variable class. This is guaranteed by the fact that **`core.NewVarDesc` must NOT create a new `VarDesc` message if its name already exists in the specified block**.
### Parameter
......@@ -155,7 +155,7 @@ class Parameter(Variable):
initialize_op_attrs)
```
When users create a parameter, s/he can call
When users create a parameter, they can call
```python
program.create_parameter(
......
......@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
......
......@@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
return it->second.get();
}
bool BlockDescBind::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end();
}
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
std::vector<VarDescBind *> res;
for (const auto &p : vars_) {
......
......@@ -43,6 +43,8 @@ class BlockDescBind {
VarDescBind *Var(const std::string &name_bytes) const;
bool HasVar(const std::string &var_name) const;
std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const;
......
......@@ -106,6 +106,7 @@ message LoDTensorDesc {
message VarDesc {
required string name = 1;
optional LoDTensorDesc lod_tensor = 2;
optional bool persistable = 3 [ default = false ];
}
message BlockDesc {
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
......@@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};
class CompileTimeInferShapeContext : public InferShapeContextBase {
public:
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), block_(block) {}
bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name);
auto length = input_names.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVar(input_names[0]);
}
bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name);
auto length = output_names.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVar(output_names[0]);
}
bool HasInputs(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
for (auto& input : input_names) {
if (!block_.HasVar(input)) return false;
}
return true;
}
bool HasOutputs(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name);
for (auto& output : output_names) {
if (!block_.HasVar(output)) return false;
}
return true;
}
DDim GetInputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void SetInputDim(const std::string& name, const DDim& dim) override {
SetInputsDim(name, {dim});
}
DDim GetOutputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetOutputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetOutputsDim(name, {dim});
}
AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Input(name);
}
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Output(name);
}
private:
DDim GetDim(const std::string& name) const override {
return framework::make_ddim(block_.Var(name)->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
block_.Var(name)->SetShape(framework::vectorize(dim));
}
const OpDescBind& op_;
const BlockDescBind& block_;
};
class RuntimeInferShapeContext : public InferShapeContextBase {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const {
bool HasInput(const std::string& name) const override {
auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasOutput(const std::string& name) const {
bool HasOutput(const std::string& name) const override {
auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}
bool HasInputs(const std::string& name) const {
bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name);
if (inputs.size() == 0UL) {
if (inputs.empty()) {
return false;
}
for (auto& input : inputs) {
......@@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true;
}
bool HasOutputs(const std::string& name) const {
bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name);
if (outputs.size() == 0UL) {
if (outputs.empty()) {
return false;
}
for (auto& output : outputs) {
......@@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true;
}
DDim GetInputDim(const std::string& name) const {
DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}
void SetInputDim(const std::string& name, const DDim& dim) {
void SetInputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Input(name), dim);
}
DDim GetOutputDim(const std::string& name) const {
DDim GetOutputDim(const std::string& name) const override {
return GetDim(op_.Output(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) {
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(const std::string& name) const {
const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name);
}
const std::vector<std::string>& Outputs(const std::string& name) const {
const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name);
}
......@@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return t;
}
DDim GetDim(const std::string& name) const {
DDim GetDim(const std::string& name) const override {
return GetTensor<false>(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) {
void SetDim(const std::string& name, const DDim& dim) override {
GetTensor<true>(name)->Resize(dim);
}
......@@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase {
});
}
protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
......
......@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into
// InferShapeContext so to replace the current InferShapeContext.
class InferShapeContextBase {
public:
virtual ~InferShapeContextBase() {}
......
......@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
// collect indice need to copy to the batch
std::vector<size_t> indice;
for (size_t seq_id = 0; seq_id < meta.size(); seq_id++) {
const auto& seq_meta = meta[seq_id];
if (index >= seq_meta.end) break;
indice.push_back(seq_meta.begin + index);
for (const auto& seq : meta) {
size_t id = seq.begin + index;
if (id >= seq.end) break;
indice.push_back(id);
}
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
// copy the indice of records in LoDTensor
......@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
result.Resize(make_ddim(record_dims_vec));
result.mutable_data<value_type>(platform::CPUPlace());
for (size_t i = 0; i < indice.size() - 1; i++) {
for (size_t i = 0; i < indice.size(); i++) {
auto index = indice[i];
auto target = result.Slice<value_type>(i, i + 1);
auto source_ = source->Slice<value_type>(index, index + 1);
target.CopyFrom<value_type>(source_, platform::CPUPlace());
}
return result;
}
// TODO(supejom) to cache lod if reasonable
LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level) {
......@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
}
result.set_lod(lod);
return result;
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/adagrad_op.h"
namespace paddle {
namespace operators {
class AdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(MomentOut) of AdagradOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"LearningRate should have one element");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdagradOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment"),
"Param and Moment input of AdagradOp should have the same dimension.");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
}
};
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
AddInput("LearningRate", "(Tensor) Learning rate");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("MomentOut", "(Tensor) Output second moment");
AddAttr<float>("epsilon",
"(float, default 1.0e-6) "
"Constant for numerical stability")
.SetDefault(1.0e-6f);
AddComment(R"DOC(
Adaptive Gradient Algorithm (Adagrad).
moment_out = moment + grad * grad
param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon)
The original paper(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
does not have the epsilon attribute. It is added here for numerical stability
by avoiding division by zero.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
REGISTER_OP_CPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/adagrad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adagrad,
ops::AdagradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rmsprop_op.h"
namespace paddle {
namespace operators {
class RmspropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
"Input(MeanSquare) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(Momentum_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
"Output(MeanSquareOut) of RmspropOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and grad input of RmspropOp should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
ctx->SetOutputDim("MeanSquareOut", param_dim);
}
};
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated");
AddInput("MeanSquare",
"(Tensor, default Tensor<float>)"
" The mean square value that gets updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Moment",
"(Tensor, default Tensor<float>) The moment that gets updated");
AddOutput("ParamOut", "(Tensor) Output updated parameter value");
AddOutput("MomentOut", "(Tensor) Output updated moment");
AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9f);
AddAttr<float>("momentum", "(float, default 0.0) Constant value")
.SetDefault(0.0f);
AddComment(R"DOC(
RMSprop
MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
MomentOut = momentum * Moment +
LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
ParamOut = Param - MomentOut
The original slides that proposed RMSprop: Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
REGISTER_OP_CPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/rmsprop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
p_out.device(place) = p - mom_out;
}
};
} // namespace operators
} // namespace paddle
......@@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def_static("infer_shape",
[](OpDescBind &op_desc, BlockDescBind &block) {
auto op = OpRegistry::CreateOp(*op_desc.Proto());
auto *op_with_kernel =
dynamic_cast<OperatorWithKernel *>(op.get());
if (op_with_kernel != nullptr) {
auto ctx = CompileTimeInferShapeContext(op_desc, block);
op_with_kernel->InferShape(&ctx);
} else {
PADDLE_THROW(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function",
op_desc.Type());
}
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
......
import unittest
import numpy as np
from op_test import OpTest
class TestAdagradOp1(OpTest):
''' Test Adagrad operator with explicit attributes
'''
def setUp(self):
self.op_type = "adagrad"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
lr = 0.01
epsilon = 1e-8
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'LearningRate': np.array([lr]).astype("float32")
}
self.attrs = {'epsilon': epsilon}
moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
class TestAdagradOp2(OpTest):
''' Test Adagrad operator with default attributes
'''
def setUp(self):
self.op_type = "adagrad"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
lr = 0.01
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'LearningRate': np.array([lr]).astype("float32")
}
self.attrs = {'epsilon': epsilon}
moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
def test_sum_op(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
shape = [10, 20]
# prepare input/output
x1 = block.new_var("x1")
x1.set_shape(shape)
x2 = block.new_var("x2")
x2.set_shape(shape)
out = block.new_var("out")
# prepare the operator
sum_op_desc = block.append_op()
sum_op_desc.set_type("sum")
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
core.Operator.infer_shape(sum_op_desc, block)
self.assertEqual(out.shape(), shape)
def test_mul_op(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
x_shape = [10, 20]
y_shape = [20, 30]
# prepare input/output
x1 = block.new_var("x")
x1.set_shape(x_shape)
x2 = block.new_var("y")
x2.set_shape(y_shape)
out = block.new_var("out")
# prepare the operator
mul_op_desc = block.append_op()
mul_op_desc.set_type("mul")
mul_op_desc.set_input("X", ["x"])
mul_op_desc.set_input("Y", ["y"])
mul_op_desc.set_output("Out", ["out"])
mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1)
core.Operator.infer_shape(mul_op_desc, block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestRmspropOp1(OpTest):
''' Test RMSProp with explicit inputs
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1e-6
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
class TestRmspropOp2(OpTest):
'''Test RMSProp with defaukt values for attributes
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1.0e-10
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册