提交 bd369c35 编写于 作者: D dongzhihong

"remove type alias header file"

上级 72fb86a2
...@@ -12,14 +12,17 @@ ...@@ -12,14 +12,17 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #include "paddle/operators/net_op.h"
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FullyConnectedOp : public framework::NetOp { using OpRegistry = framework::OpRegistry;
class FullyConnectedOp : public NetOp {
public: public:
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
......
...@@ -14,17 +14,19 @@ ...@@ -14,17 +14,19 @@
#include "paddle/operators/recurrent_op.h" #include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring> #include <cstring>
#include <sstream> #include <sstream>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const { void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external) seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>() ->GetMutable<Tensor>()
...@@ -135,10 +137,11 @@ void RecurrentOp::Init() { ...@@ -135,10 +137,11 @@ void RecurrentOp::Init() {
alg_.Init(std::move(arg)); alg_.Init(std::move(arg));
} }
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class RecurrentAlgorithmProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto, RecurrentAlgorithmProtoAndCheckerMaker(framework::OpProto* proto,
OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName; const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto // inputs and outputs stored in proto
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
namespace rnn { namespace rnn {
namespace fmw = paddle::framework; namespace f = paddle::framework;
using Tensor = framework::Tensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes, void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len, const std::vector<Link>& inlinks, const size_t seq_len,
...@@ -30,10 +32,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, ...@@ -30,10 +32,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
inlinks[i].external); inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>(); Tensor* input = input_var->GetMutable<Tensor>();
fmw::DDim dims = input->dims(); f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len, PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length"); "all the inlinks must have same length");
fmw::DDim step_dims = slice_ddim(dims, 1, dims.size()); f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) { for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input = Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>(); step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
...@@ -58,11 +60,10 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes, ...@@ -58,11 +60,10 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal); auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope", PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal); outlinks[i].internal);
fmw::DDim step_dims = f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
step_scope_var->template GetMutable<Tensor>()->dims();
std::vector<int> dims_vec = vectorize(step_dims); std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len); dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(fmw::make_ddim(dims_vec)); output->Resize(f::make_ddim(dims_vec));
} else { } else {
output->mutable_data<float>(platform::CPUPlace()); output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) { for (size_t j = 0; j < seq_len; j++) {
...@@ -104,7 +105,7 @@ void LinkMemories(const std::vector<Scope*>& scopes, ...@@ -104,7 +105,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
} }
void InitArgument(const ArgumentName& name, Argument* arg, void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op) { const framework::OperatorBase& op) {
arg->step_net = op.Input(name.step_net); arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes); arg->step_scopes = op.Output(name.step_scopes);
......
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
#include <string> #include <string>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace rnn { namespace rnn {
using Scope = framework::Scope;
/** /**
* Memory of a RNN (same as the role of `Momory` in PaddlePaddle). * Memory of a RNN (same as the role of `Momory` in PaddlePaddle).
* *
...@@ -86,7 +87,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes, ...@@ -86,7 +87,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes,
const int offset, bool infer_shape_mode); const int offset, bool infer_shape_mode);
void InitArgument(const ArgumentName& name, Argument* arg, void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op); const framework::OperatorBase& op);
} // namespace rnn } // namespace rnn
} // namespace operators } // namespace operators
......
...@@ -41,9 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,9 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL, PADDLE_ENFORCE(ctx.InputSize() == 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"); "Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL, PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
......
...@@ -62,9 +62,9 @@ class SoftmaxKernel : public framework::OpKernel { ...@@ -62,9 +62,9 @@ class SoftmaxKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxGradKernel : public OpKernel { class SoftmaxGradKernel : public framework::OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>(); std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y"); auto Y = context.Input<Tensor>("Y");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using OperatorBase = framework::OperatorBase;
using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using Scope = framework::Scope;
using OperatorWithKernel = framework::OperatorWithKernel;
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册