提交 9620df44 编写于 作者: Y Yi Wang

Reformat paddle/operators/* strictly following Google Style Guide

上级 559b0224
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
...@@ -33,7 +33,7 @@ protected: ...@@ -33,7 +33,7 @@ protected:
}; };
class AddOpMaker : public OpProtoAndCheckerMaker { class AddOpMaker : public OpProtoAndCheckerMaker {
public: public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op"); AddInput("X", "The first input of add op");
...@@ -48,7 +48,7 @@ The equation is: Out = X + Y ...@@ -48,7 +48,7 @@ The equation is: Out = X + Y
}; };
class AddOpGrad : public OperatorWithKernel { class AddOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const InferShapeContext &ctx) const override {}
}; };
......
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0); auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input<Tensor>(1); auto input1 = context.Input<Tensor>(1);
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, PADDLE_ENFORCE(ctx.InputSize() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
...@@ -37,7 +37,7 @@ protected: ...@@ -37,7 +37,7 @@ protected:
}; };
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker { class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public: public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker) OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("X", "The first input of OnehotCrossEntropyOp");
...@@ -54,8 +54,7 @@ OnehotCrossEntropy Operator. ...@@ -54,8 +54,7 @@ OnehotCrossEntropy Operator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(onehot_cross_entropy, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>); ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel { class OnehotCrossEntropyOpKernel : public OpKernel {
public: public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const ExecutionContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
......
...@@ -18,31 +18,29 @@ namespace paddle { ...@@ -18,31 +18,29 @@ namespace paddle {
namespace operators { namespace operators {
class FullyConnectedOp : public NetOp { class FullyConnectedOp : public NetOp {
public: public:
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
{ {
Input("X"), Input("W"), Input("X"), Input("W"),
}, },
{Output("before_act")}, {Output("before_act")}, {}));
{}));
auto b = Input("b"); auto b = Input("b");
if (b != framework::kEmptyVarName) { if (b != framework::kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add", AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")}, {Output("before_act"), Input("b")},
{Output("before_act")}, {Output("before_act")}, {}));
{}));
} }
auto activation = GetAttr<std::string>("activation"); auto activation = GetAttr<std::string>("activation");
AddOp(OpRegistry::CreateOp( AddOp(OpRegistry::CreateOp(activation, {Output("before_act")},
activation, {Output("before_act")}, {Output("Y")}, {})); {Output("Y")}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker { class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
public: public:
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker) FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator"); AddInput("X", "the input of fc operator");
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL, PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Input size of FillZerosLikeOp must be one."); "Input size of FillZerosLikeOp must be one.");
...@@ -36,7 +36,7 @@ protected: ...@@ -36,7 +36,7 @@ protected:
}; };
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillZerosLikeOpMaker(framework::OpProto *proto, FillZerosLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
...@@ -52,8 +52,7 @@ The output will have the same size with input. ...@@ -52,8 +52,7 @@ The output will have the same size with input.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(fill_zeros_like, REGISTER_OP(fill_zeros_like, paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker); paddle::operators::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
......
...@@ -22,7 +22,7 @@ namespace operators { ...@@ -22,7 +22,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel { class FillZerosLikeKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>(0); auto* output = context.Output<framework::Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class MeanOp : public OperatorWithKernel { class MeanOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
...@@ -29,7 +29,7 @@ protected: ...@@ -29,7 +29,7 @@ protected:
}; };
class MeanOpMaker : public OpProtoAndCheckerMaker { class MeanOpMaker : public OpProtoAndCheckerMaker {
public: public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
...@@ -39,7 +39,7 @@ public: ...@@ -39,7 +39,7 @@ public:
}; };
class MeanGradOp : public OperatorWithKernel { class MeanGradOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix) ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims()); ->Resize(ctx.Input<Tensor>("X")->dims());
......
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MeanKernel : public OpKernel { class MeanKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0); auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
template <typename Place, typename T> template <typename Place, typename T>
class MeanGradKernel : public OpKernel { class MeanGradKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix); auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1, PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class MulOp : public OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
auto dim0 = ctx.Input<Tensor>(0)->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
...@@ -34,7 +34,7 @@ protected: ...@@ -34,7 +34,7 @@ protected:
}; };
class MulOpMaker : public OpProtoAndCheckerMaker { class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
...@@ -49,7 +49,7 @@ The equation is: Out = X * Y ...@@ -49,7 +49,7 @@ The equation is: Out = X * Y
}; };
class MulOpGrad : public OperatorWithKernel { class MulOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public OpKernel { class MulKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
......
...@@ -40,7 +40,7 @@ namespace operators { ...@@ -40,7 +40,7 @@ namespace operators {
* it defines. * it defines.
*/ */
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
...@@ -90,7 +90,7 @@ public: ...@@ -90,7 +90,7 @@ public:
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
bool add_op_done_{false}; bool add_op_done_{false};
template <typename T, typename KeyType> template <typename T, typename KeyType>
......
...@@ -12,7 +12,7 @@ static int infer_shape_cnt = 0; ...@@ -12,7 +12,7 @@ static int infer_shape_cnt = 0;
static int run_cnt = 0; static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape(const framework::Scope& scope) const override { void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
...@@ -23,7 +23,7 @@ public: ...@@ -23,7 +23,7 @@ public:
}; };
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
......
...@@ -28,14 +28,12 @@ namespace operators { ...@@ -28,14 +28,12 @@ namespace operators {
namespace rnn { namespace rnn {
void SegmentInputs(const std::vector<Scope*>& step_scopes, void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const std::vector<Link>& inlinks, const size_t seq_len,
const size_t seq_len,
bool infer_shape_mode) { bool infer_shape_mode) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) { for (size_t i = 0; i < inlinks.size(); ++i) {
auto input_var = step_scopes[0]->FindVar(inlinks[i].external); auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
PADDLE_ENFORCE(input_var != nullptr, PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
"input link [%s] is not in scope.",
inlinks[i].external); inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>(); Tensor* input = input_var->GetMutable<Tensor>();
framework::DDim dims = input->dims(); framework::DDim dims = input->dims();
...@@ -54,13 +52,11 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, ...@@ -54,13 +52,11 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
} }
void ConcatOutputs(const std::vector<Scope*>& step_scopes, void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks, const std::vector<Link>& outlinks, const size_t seq_len,
const size_t seq_len,
bool infer_shape_mode) { bool infer_shape_mode) {
for (size_t i = 0; i < outlinks.size(); i++) { for (size_t i = 0; i < outlinks.size(); i++) {
auto output_var = step_scopes[0]->FindVar(outlinks[i].external); auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
PADDLE_ENFORCE(output_var != nullptr, PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
"output link [%s] is not in scope.",
outlinks[i].external); outlinks[i].external);
Tensor* output = output_var->GetMutable<Tensor>(); Tensor* output = output_var->GetMutable<Tensor>();
if (infer_shape_mode) { if (infer_shape_mode) {
...@@ -87,22 +83,16 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes, ...@@ -87,22 +83,16 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
void LinkMemories(const std::vector<Scope*>& scopes, void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories, const std::vector<rnn::MemoryAttr>& memories,
const size_t step_id, const size_t step_id, const int offset,
const int offset,
bool infer_shape_mode) { bool infer_shape_mode) {
PADDLE_ENFORCE(step_id < scopes.size(), PADDLE_ENFORCE(step_id < scopes.size(),
"step [%d] is out of range of step scopes' size [%d]", "step [%d] is out of range of step scopes' size [%d]", step_id,
step_id,
scopes.size()); scopes.size());
PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0, PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0,
"offset [%d] must be large than -[%d]", "offset [%d] must be large than -[%d]", offset, step_id);
offset,
step_id);
PADDLE_ENFORCE(step_id + offset < scopes.size(), PADDLE_ENFORCE(step_id + offset < scopes.size(),
"offset [%d] is out of range, it must be less than (%d - %d)", "offset [%d] is out of range, it must be less than (%d - %d)",
offset, offset, scopes.size(), step_id);
scopes.size(),
step_id);
auto scope = scopes[step_id]; auto scope = scopes[step_id];
auto linked_scope = scopes[step_id + offset]; auto linked_scope = scopes[step_id + offset];
for (auto& attr : memories) { for (auto& attr : memories) {
...@@ -116,8 +106,7 @@ void LinkMemories(const std::vector<Scope*>& scopes, ...@@ -116,8 +106,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
} }
} }
void InitArgument(const ArgumentName& name, void InitArgument(const ArgumentName& name, Argument* arg,
Argument* arg,
const OperatorBase& op) { const OperatorBase& op) {
arg->step_net = op.Input(name.step_net); arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes); arg->step_scopes = op.Output(name.step_scopes);
...@@ -126,8 +115,7 @@ void InitArgument(const ArgumentName& name, ...@@ -126,8 +115,7 @@ void InitArgument(const ArgumentName& name,
auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias); auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias);
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(), PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
"the size of inlinks and inlink_alias don't match:%d,%d", "the size of inlinks and inlink_alias don't match:%d,%d",
inlinks.size(), inlinks.size(), inlink_alias.size());
inlink_alias.size());
for (size_t i = 0; i < inlinks.size(); ++i) { for (size_t i = 0; i < inlinks.size(); ++i) {
rnn::Link link; rnn::Link link;
link.external = inlinks[i]; link.external = inlinks[i];
...@@ -139,8 +127,7 @@ void InitArgument(const ArgumentName& name, ...@@ -139,8 +127,7 @@ void InitArgument(const ArgumentName& name,
auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias); auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias);
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(), PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
"the size of outlinks and outlink_alias don't match:%d,%d", "the size of outlinks and outlink_alias don't match:%d,%d",
outlinks.size(), outlinks.size(), outlink_alias.size());
outlink_alias.size());
for (size_t i = 0; i < outlinks.size(); ++i) { for (size_t i = 0; i < outlinks.size(); ++i) {
rnn::Link link; rnn::Link link;
link.external = outlinks[i]; link.external = outlinks[i];
...@@ -156,12 +143,10 @@ void InitArgument(const ArgumentName& name, ...@@ -156,12 +143,10 @@ void InitArgument(const ArgumentName& name,
PADDLE_ENFORCE(memories.size() == boot_memories.size(), PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of memories, boot_memories don't match:%d,%d", "the size of memories, boot_memories don't match:%d,%d",
memories.size(), memories.size(), boot_memories.size());
boot_memories.size());
PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(), PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
"the size of pre_memories, boot_memories don't match:%d,%d", "the size of pre_memories, boot_memories don't match:%d,%d",
pre_memories.size(), pre_memories.size(), boot_memories.size());
boot_memories.size());
PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set"); PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");
for (size_t i = 0; i < memories.size(); ++i) { for (size_t i = 0; i < memories.size(); ++i) {
...@@ -181,39 +166,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { ...@@ -181,39 +166,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
->dims()[0]; ->dims()[0];
CreateScopes(scope); CreateScopes(scope);
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs( rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/); InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (size_t i = 0; i < seq_len_; i++) { for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) { if (i > 0) {
rnn::LinkMemories( rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
} }
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]); net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
} }
rnn::ConcatOutputs( rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
} }
void RecurrentAlgorithm::Run(const Scope& scope, void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs( rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/); InitMemories(step_scopes[0], false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) { for (size_t step_id = 0; step_id < seq_len_; step_id++) {
if (step_id > 0) { if (step_id > 0) {
rnn::LinkMemories( rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
} }
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx); net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
} }
rnn::ConcatOutputs( rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
} }
void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
...@@ -245,8 +230,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope, ...@@ -245,8 +230,7 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
for (auto& attr : arg_->memories) { for (auto& attr : arg_->memories) {
Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>(); Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", "memory [%s]'s boot variable [%s] not exists", attr.var,
attr.var,
attr.boot_var); attr.boot_var);
Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>(); Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
if (infer_shape_mode) { if (infer_shape_mode) {
...@@ -257,25 +241,15 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope, ...@@ -257,25 +241,15 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
} }
} }
const rnn::ArgumentName RecurrentOp::kArgName{"step_net", const rnn::ArgumentName RecurrentOp::kArgName{
"step_scopes", "step_net", "step_scopes", "inlinks",
"inlinks", "outlinks", "inlink_alias", "outlink_alias",
"outlinks", "memories", "pre_memories", "boot_memories"};
"inlink_alias",
"outlink_alias", const rnn::ArgumentName RecurrentGradientOp::kArgName{
"memories", "step_net", "step_scopes", "outlink@grad",
"pre_memories", "inlink@grad", "inlink_alias", "outlink_alias",
"boot_memories"}; "memories", "pre_memories", "boot_memories@grad"};
const rnn::ArgumentName RecurrentGradientOp::kArgName{"step_net",
"step_scopes",
"outlink@grad",
"inlink@grad",
"inlink_alias",
"outlink_alias",
"memories",
"pre_memories",
"boot_memories@grad"};
void RecurrentOp::Init() { void RecurrentOp::Init() {
OperatorBase::Init(); OperatorBase::Init();
...@@ -285,7 +259,7 @@ void RecurrentOp::Init() { ...@@ -285,7 +259,7 @@ void RecurrentOp::Init() {
} }
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto, RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker) OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
...@@ -316,31 +290,29 @@ public: ...@@ -316,31 +290,29 @@ public:
void RecurrentGradientAlgorithm::Run( void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const { const Scope& scope, const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs( rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) { if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories( rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
} }
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx); net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
} }
LinkBootMemoryGradients(step_scopes[0], false); LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs( rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); false /*infer_shape_mode*/);
} }
void RecurrentGradientAlgorithm::LinkBootMemoryGradients( void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope, bool infer_shape_mode) const { Scope* step_scope, bool infer_shape_mode) const {
for (auto& attr : arg_->memories) { for (auto& attr : arg_->memories) {
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exists", "memory variable [%s] does not exists", attr.var);
attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"boot variable [%s] does not exists", "boot variable [%s] does not exists", attr.boot_var);
attr.boot_var);
Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>(); Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
Tensor* boot_mem_grad = Tensor* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>(); step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
...@@ -357,19 +329,19 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { ...@@ -357,19 +329,19 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
->GetMutable<Tensor>() ->GetMutable<Tensor>()
->dims()[0]; ->dims()[0];
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs( rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) { if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories( rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
} }
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]); net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
} }
rnn::ConcatOutputs( rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
} }
...@@ -383,6 +355,5 @@ void RecurrentGradientOp::Init() { ...@@ -383,6 +355,5 @@ void RecurrentGradientOp::Init() {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(recurrent_op, REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp,
paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
...@@ -69,23 +69,19 @@ struct ArgumentName { ...@@ -69,23 +69,19 @@ struct ArgumentName {
* Prepare inputs for each step net. * Prepare inputs for each step net.
*/ */
void SegmentInputs(const std::vector<framework::Scope*>& step_scopes, void SegmentInputs(const std::vector<framework::Scope*>& step_scopes,
const std::vector<Link>& inlinks, const std::vector<Link>& inlinks, const size_t seq_len,
const size_t seq_len,
bool infer_shape_mode); bool infer_shape_mode);
/** /**
* Process outputs of step nets and merge to variables. * Process outputs of step nets and merge to variables.
*/ */
void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes, void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes,
const std::vector<Link>& outlinks, const std::vector<Link>& outlinks, const size_t seq_len,
const size_t seq_len,
bool infer_shape_mode); bool infer_shape_mode);
void LinkMemories(const std::vector<framework::Scope*>& step_scopes, void LinkMemories(const std::vector<framework::Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const std::vector<MemoryAttr>& memories, const size_t step_id,
const size_t step_id, const int offset, bool infer_shape_mode);
const int offset,
bool infer_shape_mode);
void InitArgument(const ArgumentName& name, Argument* arg); void InitArgument(const ArgumentName& name, Argument* arg);
...@@ -100,7 +96,7 @@ void InitArgument(const ArgumentName& name, Argument* arg); ...@@ -100,7 +96,7 @@ void InitArgument(const ArgumentName& name, Argument* arg);
// Refer to: https://arxiv.org/pdf/1502.02367.pdf // Refer to: https://arxiv.org/pdf/1502.02367.pdf
class RecurrentAlgorithm { class RecurrentAlgorithm {
public: public:
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const; const platform::DeviceContext& dev_ctx) const;
...@@ -111,7 +107,7 @@ public: ...@@ -111,7 +107,7 @@ public:
*/ */
void InferShape(const framework::Scope& scope) const; void InferShape(const framework::Scope& scope) const;
protected: protected:
/* /*
* The step scopes will be stored in the father scope as a variable. * The step scopes will be stored in the father scope as a variable.
* *
...@@ -128,7 +124,7 @@ protected: ...@@ -128,7 +124,7 @@ protected:
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
private: private:
std::unique_ptr<rnn::Argument> arg_; std::unique_ptr<rnn::Argument> arg_;
mutable size_t seq_len_; mutable size_t seq_len_;
}; };
...@@ -144,7 +140,7 @@ class RecurrentGradientAlgorithm { ...@@ -144,7 +140,7 @@ class RecurrentGradientAlgorithm {
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an * lot, and the latter is a wrapper acts like an dapter for it to make RNN an
* operator. * operator.
*/ */
public: public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); } void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
...@@ -158,20 +154,20 @@ public: ...@@ -158,20 +154,20 @@ public:
*/ */
void InferShape(const framework::Scope& scope) const; void InferShape(const framework::Scope& scope) const;
protected: protected:
inline const std::vector<framework::Scope*>& GetStepScopes( inline const std::vector<framework::Scope*>& GetStepScopes(
const framework::Scope& scope) const { const framework::Scope& scope) const {
return *scope.FindVar(arg_->step_scopes) return *scope.FindVar(arg_->step_scopes)
->GetMutable<std::vector<framework::Scope*>>(); ->GetMutable<std::vector<framework::Scope*>>();
} }
private: private:
std::unique_ptr<rnn::Argument> arg_; std::unique_ptr<rnn::Argument> arg_;
mutable size_t seq_len_; mutable size_t seq_len_;
}; };
class RecurrentOp final : public framework::OperatorBase { class RecurrentOp final : public framework::OperatorBase {
public: public:
void Init() override; void Init() override;
/** /**
...@@ -188,12 +184,12 @@ public: ...@@ -188,12 +184,12 @@ public:
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
private: private:
RecurrentAlgorithm alg_; RecurrentAlgorithm alg_;
}; };
class RecurrentGradientOp final : public framework::OperatorBase { class RecurrentGradientOp final : public framework::OperatorBase {
public: public:
void Init() override; void Init() override;
/** /**
...@@ -210,7 +206,7 @@ public: ...@@ -210,7 +206,7 @@ public:
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
private: private:
RecurrentGradientAlgorithm alg_; RecurrentGradientAlgorithm alg_;
}; };
......
...@@ -29,7 +29,7 @@ using framework::make_ddim; ...@@ -29,7 +29,7 @@ using framework::make_ddim;
using framework::DDim; using framework::DDim;
class RecurrentOpTest : public ::testing::Test { class RecurrentOpTest : public ::testing::Test {
protected: protected:
virtual void SetUp() override { virtual void SetUp() override {
CreateGlobalVariables(); CreateGlobalVariables();
CreateStepNet(); CreateStepNet();
...@@ -174,7 +174,7 @@ TEST_F(RecurrentOpTest, Run) { ...@@ -174,7 +174,7 @@ TEST_F(RecurrentOpTest, Run) {
} }
class RecurrentGradientAlgorithmTest : public ::testing::Test { class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected: protected:
virtual void SetUp() override { virtual void SetUp() override {
CreateGlobalVariables(); CreateGlobalVariables();
CreateStepScopes(); CreateStepScopes();
...@@ -277,13 +277,11 @@ protected: ...@@ -277,13 +277,11 @@ protected:
LOG(INFO) << "create variable step_net"; LOG(INFO) << "create variable step_net";
Variable* var = scope_.NewVar("step_net"); Variable* var = scope_.NewVar("step_net");
auto net = var->GetMutable<NetOp>(); auto net = var->GetMutable<NetOp>();
net->AddOp(OpRegistry::CreateOp("mul", net->AddOp(OpRegistry::CreateOp("mul", {"rnn/h_pre", "rnn/w", "rnn/s_grad"},
{"rnn/h_pre", "rnn/w", "rnn/s_grad"}, {"rnn/h_pre_grad", "rnn/w_grad"}, {}));
{"rnn/h_pre_grad", "rnn/w_grad"},
{}));
net->AddOp(OpRegistry::CreateOp( net->AddOp(OpRegistry::CreateOp("add_two", {"rnn/h_grad"},
"add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {})); {"rnn/x_grad", "rnn/s_grad"}, {}));
net->CompleteAddOp(); net->CompleteAddOp();
} }
...@@ -297,9 +295,7 @@ protected: ...@@ -297,9 +295,7 @@ protected:
inlink.internal = "rnn/x"; inlink.internal = "rnn/x";
auto step_scopes = auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>(); scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
rnn::SegmentInputs(*step_scopes, rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10,
std::vector<rnn::Link>{inlink},
10,
true /*infer_shape_mode*/); true /*infer_shape_mode*/);
} }
...@@ -314,8 +310,8 @@ protected: ...@@ -314,8 +310,8 @@ protected:
auto step_scopes = auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>(); scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
for (int i = 1; i < 10; ++i) { for (int i = 1; i < 10; ++i) {
rnn::LinkMemories( rnn::LinkMemories(*step_scopes, memories, i, -1,
*step_scopes, memories, i, -1, true /*infer_shape_mode*/); true /*infer_shape_mode*/);
} }
} }
......
...@@ -17,7 +17,7 @@ namespace paddle { ...@@ -17,7 +17,7 @@ namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL, PADDLE_ENFORCE(ctx.InputSize() == 2UL,
"Two inputs is needed by rowwise add"); "Two inputs is needed by rowwise add");
...@@ -33,7 +33,7 @@ protected: ...@@ -33,7 +33,7 @@ protected:
}; };
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
......
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel { class RowWiseAddKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto out = context.Output<Tensor>(0); auto out = context.Output<Tensor>(0);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
...@@ -32,7 +32,7 @@ protected: ...@@ -32,7 +32,7 @@ protected:
}; };
class SGDOpMaker : public OpProtoAndCheckerMaker { class SGDOpMaker : public OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker) SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("param", "input parameter");
......
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public OpKernel { class SGDOpKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
......
...@@ -17,7 +17,7 @@ namespace paddle { ...@@ -17,7 +17,7 @@ namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
...@@ -26,7 +26,7 @@ protected: ...@@ -26,7 +26,7 @@ protected:
}; };
class SigmoidOpMaker : public OpProtoAndCheckerMaker { class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public: public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input"); AddInput("X", "sigmoid input");
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
}; };
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override {} void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad"; LOG(INFO) << "SigmoidGrad";
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public OpKernel { class SigmoidKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0); auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
......
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL, PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Only one input is need for softmax"); "Only one input is need for softmax");
...@@ -31,7 +31,7 @@ protected: ...@@ -31,7 +31,7 @@ protected:
}; };
class SoftmaxOpMaker : public OpProtoAndCheckerMaker { class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker) SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax"); AddInput("X", "input of softmax");
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
}; };
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL, PADDLE_ENFORCE(ctx.InputSize() == 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"); "Input of SoftmaxOpGrad should be 3, X, Y, YG");
......
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public OpKernel { class SoftmaxKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X"); auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y"); auto output = context.Output<Tensor>("Y");
...@@ -63,7 +63,7 @@ public: ...@@ -63,7 +63,7 @@ public:
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxGradKernel : public OpKernel { class SoftmaxGradKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>(); std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
......
...@@ -26,21 +26,16 @@ using OperatorBase = framework::OperatorBase; ...@@ -26,21 +26,16 @@ using OperatorBase = framework::OperatorBase;
using InferShapeContext = framework::InferShapeContext; using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext; using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable; using Variable = framework::Variable;
template <typename T, template <typename T, int MajorType = Eigen::RowMajor,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>; using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, template <typename T, int MajorType = Eigen::RowMajor,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, template <typename T, int MajorType = Eigen::RowMajor,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, template <typename T, size_t D, int MajorType = Eigen::RowMajor,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册