Commit 35086d06 authored by: Q qijun

Merge remote-tracking branch 'baidu/develop' into get_places_op

@@ -312,3 +312,9 @@ sequence_softmax
.. autofunction:: paddle.v2.fluid.layers.sequence_softmax
    :noindex:
+
+reduce_sum
+----------
+.. autofunction:: paddle.v2.fluid.layers.reduce_sum
+    :noindex:
+
# Design Doc: Execute the Program with Multi CPU
## Abstract
This design doc proposes an approach to run a user-defined Op graph on multiple CPUs: an auto transpiler converts the user-defined Op graph into a multi-CPU Op graph, and a `ParallelDo` Op runs that graph.
## Transpiler
The user-defined, single-threaded graph:

<img src="src/multi-threads/single-thread@3x.png" width="300">

After conversion:
<img src="src/multi-threads/multi-threads@3x.png" width="1000">
## Implementation
- `Multi-CPU Transpiler` will convert the graph to a multi-CPU graph,
  which will be executed with multiple threads.
- `BlockingCounter` initializes (`Init`) and decrements (`Decrement`) an atomic counter,
  and `Wait` blocks until the atomic counter becomes `0`:
```cpp
BlockingCounter bc(thread_count);
for (int i = 0; i < thread_count; ++i) {
  thread_pool->Start([&bc] { bc.DecrementCount(); });  // each thread signals completion
}
bc.Wait();  // block until the counter reaches 0
```
- `ParallelDo` Operator
  - Initialize a thread pool, which is a singleton.
  - Take a block id as input, and run the specified block on independent scopes
    with multiple threads (see the sketch after this list).
  - Initialize a `BlockingCounter` instance and wait until all threads are done.
- `Split` Operator will split the input Tensor into a TensorArray.
- `Merge` merges all the gradients calculated in different threads
  with a `mean/sum/max/min...` method, and then runs the Optimizer Op to optimize `W`.
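
A minimal sketch of how `ParallelDo` could tie these pieces together, assuming a singleton `ThreadPool`, an `Executor` that runs a block in a scope, and the `BlockingCounter` above; all names and signatures here are illustrative, not Fluid's actual API:

```cpp
// Illustrative only: run the block `block_id` on `thread_count` independent
// scopes in parallel, then block until every thread has finished.
void ParallelDoRun(Scope* scope, int block_id, int thread_count) {
  ThreadPool* pool = ThreadPool::GetInstance();  // singleton thread pool
  BlockingCounter bc(thread_count);
  for (int i = 0; i < thread_count; ++i) {
    Scope* thread_scope = &scope->NewScope();    // independent scope per thread
    pool->Start([&bc, thread_scope, block_id] {
      Executor().Run(block_id, thread_scope);    // execute the specified block
      bc.DecrementCount();                       // signal completion
    });
  }
  bc.Wait();  // wait until all threads are done
}
```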
## TODO
- Parallelize the optimizer stage as well, since we could assign the parameters
  to different threads and run the optimizer Ops with multiple threads.
## Background
Every operator has many kernels because Fluid supports multiple data types, places, and data layouts. We use `KernelType` to describe the kernel types that operators can hold.
The `KernelType` is as follows.
```cpp
struct KernelType {
  Place place_;
  DataType data_type_;
  LayoutType layout_;
};
```
The `place_` is a descriptor of the device and the computational library, e.g., `MKLDNNPlace`, `CUDAPlace`.
The `data_type_` is the data type that this kernel performs on, e.g., `FP32`, `INT64`. Note that one kernel may have inputs with different data types; however, one of them will be the major `data_type`. For example, `cross_entropy` takes `int64` as its label, and `double`/`float` as its input logit and output cost; the major `data_type` of `cross_entropy` is `float`/`double`.
The `layout_` is useful for some computational libraries. For example, MKLDNN uses many kinds of layouts, such as `nChw8c`. Each kind of layout will invoke a different kernel.
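
For illustration only, the kernel variants of a single operator could be described by values like the following; the `LayoutType` enumerators and place constructors are assumptions, since only the struct above is given here:

```cpp
// Hypothetical kernel types one operator might register. The field order
// follows the KernelType struct above: place_, data_type_, layout_.
KernelType cpu_fp32{CPUPlace(), DataType::FP32, LayoutType::kNCHW};
KernelType cuda_fp32{CUDAPlace(0), DataType::FP32, LayoutType::kNCHW};
KernelType mkldnn_fp32{MKLDNNPlace(), DataType::FP32, LayoutType::kNChw8c};
```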
## Problem
Ideally, we would register a kernel for every operator and every kernel type. However, this is impracticable in the following situations.
1. Some operators, like CRF, are complicated and inefficient to implement on a GPU. The CRF operator will only have a CPU kernel.
2. Some operators consume too much memory. It is better to force them onto the CPU. However, the rest of the operators in the neural network should still run on the GPU, i.e., the model-parallel problem.
3. Some layouts and places are particular. For example, MKLDNN uses `nChw8c`, and no other library uses that layout.
The problems in these situations are similar. We can formalize them as follows.
We register kernels with types $KT = \{kt_1, kt_2, kt_3, \ldots\}$ for one operator. The inputs of this operator are of some kernel type $kt_{?}$, where $kt_{?} \notin KT$. How do we cast the inputs of this operator from $kt_{?}$ to one of the kernel types in $KT$?
## Solution
Clearly, transforming the inputs of an operator to adapt to another kernel type is not specific to any particular operator, so we should register these transformation methods as global functions.
We can infer a kernel type from the inputs of an operator. We call this the `actual kernel type`: the kernel type implied by the data the operator actually receives.
We can also get a kernel type from 1) the configuration in the operator description (users may want to force the `conv` operator to use `MKL`), or 2) the place of the current executor (e.g., the executor is running on a GPU). This is the kernel type we expect the operator to be performed with; we call it the `expected kernel type`.
We transform the input data from `actual` to `expected` if the expected kernel type is not the same as the actual kernel type.
The algorithm is as follows:
```cpp
using DataTransformationFN = std::function<void(const Tensor& in, Tensor* out)>;
using KernelTypePair = std::pair<KernelType, KernelType>;

// Global registry of transformation functions, keyed by
// (actual, expected) kernel-type pairs.
map<KernelTypePair, DataTransformationFN> g_data_transformation_;

void OpWithKernel::Run() {
  vec<Tensor> inputs = ...
  auto actual_kernel_type = GetActualKernelType(inputs);

  // The expected kernel type is related to the actual kernel type.
  // For most operators, the expected kernel type is the same as the
  // actual kernel type, so we pass `actual_kernel_type` as a parameter
  // of GetExpectedKernelType.
  auto expect_kernel_type = GetExpectedKernelType(actual_kernel_type);

  // Apply the registered transformation to each input before running
  // the kernel.
  auto trans = g_data_transformation_[{actual_kernel_type, expect_kernel_type}];
  vec<Tensor> transformed_inputs(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    trans(inputs[i], &transformed_inputs[i]);
  }
  kernel.run(transformed_inputs);
}
```
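
As a usage sketch, registering one transformation into `g_data_transformation_` could look like this; the kernel-type values and the `ReorderNChw8cToNchw` helper are hypothetical, named only to make the example concrete:

```cpp
// Hypothetical: cast MKLDNN's nChw8c layout back to the plain NCHW layout
// that a CPU kernel expects. ReorderNChw8cToNchw is an assumed helper, not
// an existing Fluid API.
KernelType mkldnn_type{MKLDNNPlace(), DataType::FP32, LayoutType::kNChw8c};
KernelType plain_type{CPUPlace(), DataType::FP32, LayoutType::kNCHW};
g_data_transformation_[{mkldnn_type, plain_type}] =
    [](const Tensor& in, Tensor* out) { ReorderNChw8cToNchw(in, out); };
```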
@@ -128,7 +128,7 @@ PaddlePaddle Book is an interactive Jupyter Notebook for users and developers
AVX is a CPU instruction set that can accelerate PaddlePaddle's computation. The latest PaddlePaddle Docker image
is built with AVX enabled by default, so if your computer does not support AVX, you need to separately
-`compile <./build_from_source_cn.rst>`_ PaddlePaddle as a no-avx version.
+`compile <./build_from_source_cn.html>`_ PaddlePaddle as a no-avx version.
The following command checks whether a Linux computer supports AVX:
......
@@ -137,7 +137,7 @@ GPU driver installed before move on.
AVX is a kind of CPU instruction that can accelerate PaddlePaddle's calculations.
The latest PaddlePaddle Docker image turns AVX on by default, so, if your
computer doesn't support AVX, you'll probably need to
-`build <./build_from_source_en.rst>`_ with :code:`WITH_AVX=OFF`.
+`build <./build_from_source_en.html>`_ with :code:`WITH_AVX=OFF`.
The following command will tell you whether your computer supports AVX.
......
@@ -53,7 +53,7 @@ Kernel implementation | Kernels shared by CPU and CUDA go in `.h` files; otherwise, CPU
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor), 2D tensor of size (M x K)");
    AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@@ -82,7 +82,7 @@ The equation is: Out = X * Y
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor of scale operator.").NotInGradient();
    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
......
@@ -50,7 +50,7 @@ First, define `ProtoMaker` to describe the Operator's input, output, and additio
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor), 2D tensor of size (M x K)");
    AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
@@ -79,7 +79,7 @@ An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/de
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor of scale operator.").NotInGradient();
    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
......
@@ -2,8 +2,6 @@
The previous article described how to launch a single-node PaddlePaddle training job (Job) on a Kubernetes cluster. In this article, we describe how to run distributed PaddlePaddle training jobs on a Kubernetes cluster. For PaddlePaddle's distributed training, the article [Cluster Training](http://www.paddlepaddle.org/docs/develop/documentation/zh/howto/usage/cluster/cluster_train_cn.html) describes a method that distributes tasks remotely over SSH; in contrast, this article presents a scheme that quickly builds a PaddlePaddle container cluster on the Kubernetes container-management platform for distributed training.
-For Kubernetes concepts and how to set up and configure a Kubernetes cluster, see [k8s_basis](./k8s_basis_cn.md)
## Overall Design
Before training, the user places the configuration and the partitioned training data in pre-allocated directories on a distributed file system (different distributed file systems must be mounted in their prescribed way before the data is imported). During training, the program copies files from this directory into the container, trains, and saves the results back to this directory. The overall structure is as follows:
......
@@ -19,42 +19,42 @@ limitations under the License. */
namespace paddle {
namespace framework {
-Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
+Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
  switch (attr_desc.type()) {
-    case framework::AttrType::BOOLEAN: {
+    case proto::AttrType::BOOLEAN: {
      return attr_desc.b();
    }
-    case framework::AttrType::INT: {
+    case proto::AttrType::INT: {
      return attr_desc.i();
    }
-    case framework::AttrType::FLOAT: {
+    case proto::AttrType::FLOAT: {
      return attr_desc.f();
    }
-    case framework::AttrType::STRING: {
+    case proto::AttrType::STRING: {
      return attr_desc.s();
    }
-    case framework::AttrType::BOOLEANS: {
+    case proto::AttrType::BOOLEANS: {
      std::vector<bool> val(attr_desc.bools_size());
      for (int i = 0; i < attr_desc.bools_size(); ++i) {
        val[i] = attr_desc.bools(i);
      }
      return val;
    }
-    case framework::AttrType::INTS: {
+    case proto::AttrType::INTS: {
      std::vector<int> val(attr_desc.ints_size());
      for (int i = 0; i < attr_desc.ints_size(); ++i) {
        val[i] = attr_desc.ints(i);
      }
      return val;
    }
-    case framework::AttrType::FLOATS: {
+    case proto::AttrType::FLOATS: {
      std::vector<float> val(attr_desc.floats_size());
      for (int i = 0; i < attr_desc.floats_size(); ++i) {
        val[i] = attr_desc.floats(i);
      }
      return val;
    }
-    case framework::AttrType::STRINGS: {
+    case proto::AttrType::STRINGS: {
      std::vector<std::string> val(attr_desc.strings_size());
      for (int i = 0; i < attr_desc.strings_size(); ++i) {
        val[i] = attr_desc.strings(i);
......
@@ -27,12 +27,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
-inline AttrType AttrTypeID() {
+inline proto::AttrType AttrTypeID() {
  Attribute tmp = T();
-  return static_cast<AttrType>(tmp.which() - 1);
+  return static_cast<proto::AttrType>(tmp.which() - 1);
}
-Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
+Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
 public:
......
@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
    auto* param = block_desc->FindVarRecursive(pname);
    auto* grad = block_desc->FindVar(arg);
    if (param == nullptr) {
-      grad->SetDataType(DataType::FP32);
+      grad->SetDataType(proto::DataType::FP32);
    } else {
      grad->SetDataType(param->GetDataType());
    }
......
@@ -166,7 +166,7 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
-  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
    AddOutput("Out", "the output tensor of sum operator.");
......
@@ -128,22 +128,22 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
  return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
}
-BlockDesc *BlockDescBind::Proto() {
+proto::BlockDesc *BlockDescBind::Proto() {
  Flush();
  return desc_;
}
-BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
+BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc)
    : prog_(prog), desc_(desc), need_update_(false) {
-  for (const VarDesc &var_desc : desc_->vars()) {
+  for (const proto::VarDesc &var_desc : desc_->vars()) {
    vars_[var_desc.name()].reset(new VarDescBind(var_desc));
  }
-  for (const OpDesc &op_desc : desc_->ops()) {
+  for (const proto::OpDesc &op_desc : desc_->ops()) {
    ops_.emplace_back(new OpDescBind(op_desc, prog));
  }
}
-BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
+BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
                             ProgramDescBind *prog)
    : prog_(prog), desc_(desc) {
  need_update_ = true;
......
@@ -36,9 +36,9 @@ class ProgramDescBind;
class BlockDescBind {
 public:
-  BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
+  BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc);
-  BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
+  BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
                ProgramDescBind *prog);
  ~BlockDescBind() {
@@ -88,7 +88,7 @@ class BlockDescBind {
  void Flush();
-  BlockDesc *Proto();
+  proto::BlockDesc *Proto();
  ProgramDescBind *Program() { return this->prog_; }
@@ -97,8 +97,8 @@ class BlockDescBind {
  void ClearPBVars();
 private:
  ProgramDescBind *prog_;  // not_own
-  BlockDesc *desc_;        // not_own
+  proto::BlockDesc *desc_;  // not_own
  bool need_update_;
  std::deque<std::unique_ptr<OpDescBind>> ops_;
......
@@ -20,7 +20,8 @@
namespace paddle {
namespace framework {
-inline DataType ToDataType(std::type_index type) {
+inline proto::DataType ToDataType(std::type_index type) {
+  using namespace paddle::framework::proto;
  if (typeid(float).hash_code() == type.hash_code()) {
    return DataType::FP32;
  } else if (typeid(double).hash_code() == type.hash_code()) {
@@ -36,7 +37,8 @@ inline DataType ToDataType(std::type_index type) {
  }
}
-inline std::type_index ToTypeIndex(DataType type) {
+inline std::type_index ToTypeIndex(proto::DataType type) {
+  using namespace paddle::framework::proto;
  switch (type) {
    case DataType::FP32:
      return typeid(float);
@@ -54,7 +56,8 @@ inline std::type_index ToTypeIndex(DataType type) {
}
template <typename Visitor>
-inline void VisitDataType(DataType type, Visitor visitor) {
+inline void VisitDataType(proto::DataType type, Visitor visitor) {
+  using namespace paddle::framework::proto;
  switch (type) {
    case DataType::FP32:
      visitor.template operator()<float>();
......
@@ -90,7 +90,7 @@ struct OpInfoFiller<T, kOperator> {
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
  void operator()(const char* op_type, OpInfo* info) const {
-    info->proto_ = new OpProto;
+    info->proto_ = new proto::OpProto;
    info->checker_ = new OpAttrChecker();
    auto maker = T(info->proto_, info->checker_);
    maker.Validate();
......
@@ -41,20 +41,20 @@ Executor::Executor(const std::vector<platform::Place>& places) {
  device_contexts_.swap(borrowed_contexts);
}
-static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
+static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
-  if (var_type == VarDesc::LOD_TENSOR) {
+  if (var_type == proto::VarDesc::LOD_TENSOR) {
    var->GetMutable<LoDTensor>();
-  } else if (var_type == VarDesc::SELECTED_ROWS) {
+  } else if (var_type == proto::VarDesc::SELECTED_ROWS) {
    var->GetMutable<SelectedRows>();
-  } else if (var_type == VarDesc::FEED_MINIBATCH) {
+  } else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
    var->GetMutable<FeedFetchList>();
-  } else if (var_type == VarDesc::FETCH_LIST) {
+  } else if (var_type == proto::VarDesc::FETCH_LIST) {
    var->GetMutable<FeedFetchList>();
-  } else if (var_type == VarDesc::STEP_SCOPES) {
+  } else if (var_type == proto::VarDesc::STEP_SCOPES) {
    var->GetMutable<std::vector<framework::Scope>>();
-  } else if (var_type == VarDesc::LOD_RANK_TABLE) {
+  } else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
    var->GetMutable<LoDRankTable>();
-  } else if (var_type == VarDesc::LOD_TENSOR_ARRAY) {
+  } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
    var->GetMutable<LoDTensorArray>();
  } else {
    PADDLE_THROW(
......
@@ -14,7 +14,7 @@ limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
-package paddle.framework;
+package paddle.framework.proto;
enum AttrType {
  INT = 0;
......
@@ -46,4 +46,13 @@ void LoDRankTable::Reset(const LoD& lod, size_t level) {
}
}  // namespace framework
+std::ostream& operator<<(std::ostream& out,
+                         const framework::LoDRankTable& table) {
+  out << "NumOfSequence " << table.items().size() << "\n";
+  for (auto& each_item : table.items()) {
+    out << "\tSeq #" << each_item.index << ", Len=" << each_item.length << "\n";
+  }
+  return out;
+}
}  // namespace paddle
@@ -13,6 +13,7 @@
limitations under the License. */
#pragma once
+#include <iosfwd>
#include "paddle/framework/lod_tensor.h"
namespace paddle {
@@ -52,4 +53,8 @@ class LoDRankTable {
};
}  // namespace framework
+std::ostream& operator<<(std::ostream& out,
+                         const framework::LoDRankTable& table);
}  // namespace paddle
@@ -197,7 +197,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
  {  // the 2nd field, tensor description
     // int32_t size
     // void* protobuf message
-    framework::TensorDesc desc;
+    proto::TensorDesc desc;
    desc.set_data_type(framework::ToDataType(tensor.type()));
    auto dims = framework::vectorize(tensor.dims());
    auto *pb_dims = desc.mutable_dims();
@@ -262,7 +262,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
  uint32_t version;
  is.read(reinterpret_cast<char *>(&version), sizeof(version));
  PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
-  framework::TensorDesc desc;
+  proto::TensorDesc desc;
  {  // int32_t size
     // proto buffer
    int32_t size;
@@ -281,16 +281,16 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
    void *buf;
    platform::Place cpu = platform::CPUPlace();
    switch (desc.data_type()) {
-      case framework::FP32:
+      case proto::FP32:
        buf = tensor->mutable_data<float>(cpu);
        break;
-      case framework::FP64:
+      case proto::FP64:
        buf = tensor->mutable_data<double>(cpu);
        break;
-      case framework::INT32:
+      case proto::INT32:
        buf = tensor->mutable_data<int>(cpu);
        break;
-      case framework::INT64:
+      case proto::INT64:
        buf = tensor->mutable_data<int64_t>(cpu);
        break;
      default:
......
@@ -184,6 +184,18 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
  return tensor;
}
+// Get the absolute offset of lod[start_level][start_idx:end_idx] and the
+// relative lengths of the details for every level (i.e., [start_level:]).
+//
+// For example,
+//   lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]]
+//   start_level = 0
+//   start_idx = 1
+//   end_idx = 3
+//
+// Returns:
+//   LoD = [[1, 4], [2, 4, 2, 3, 2]]
+//   pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
    const LoD& lod, size_t start_idx, size_t end_idx, size_t start_level);
......
@@ -58,11 +58,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    if (in_var->GetType() != VarDesc::LOD_TENSOR) {
+    if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
      VLOG(3) << "input " << in << " is not LodTensor";
      return;
    }
-    PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
+    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
                      "The %d-th output of Output(%s) must be LoDTensor.", j,
                      out);
    out_var->SetLoDLevel(in_var->GetLodLevel());
@@ -70,7 +70,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
  bool IsRuntime() const override;
 protected:
-  VarDesc::VarType GetVarType(const std::string &name) const override;
+  proto::VarDesc::VarType GetVarType(const std::string &name) const override;
  DDim GetDim(const std::string &name) const override;
@@ -90,12 +90,12 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
  need_update_ = true;
}
-OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
+OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
    : desc_(desc), need_update_(false) {
  // restore inputs_
  int input_size = desc_.inputs_size();
  for (int i = 0; i < input_size; ++i) {
-    const OpDesc::Var &var = desc_.inputs(i);
+    const proto::OpDesc::Var &var = desc_.inputs(i);
    std::vector<std::string> &args = inputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
@@ -106,7 +106,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
  // restore outputs_
  int output_size = desc_.outputs_size();
  for (int i = 0; i < output_size; ++i) {
-    const OpDesc::Var &var = desc_.outputs(i);
+    const proto::OpDesc::Var &var = desc_.outputs(i);
    std::vector<std::string> &args = outputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
@@ -115,9 +115,9 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
    }
  }
  // restore attrs_
-  for (const OpDesc::Attr &attr : desc_.attrs()) {
+  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    std::string attr_name = attr.name();
-    if (attr.type() != AttrType::BLOCK) {
+    if (attr.type() != proto::AttrType::BLOCK) {
      attrs_[attr_name] = GetAttrValue(attr);
    } else {
      auto bid = attr.block_idx();
@@ -126,7 +126,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
  }
}
-OpDesc *OpDescBind::Proto() {
+proto::OpDesc *OpDescBind::Proto() {
  Flush();
  return &desc_;
}
@@ -175,10 +175,10 @@ void OpDescBind::SetOutput(const std::string &param_name,
  this->outputs_[param_name] = args;
}
-AttrType OpDescBind::GetAttrType(const std::string &name) const {
+proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
-  return static_cast<AttrType>(it->second.which() - 1);
+  return static_cast<proto::AttrType>(it->second.which() - 1);
}
std::vector<std::string> OpDescBind::AttrNames() const {
@@ -253,8 +253,8 @@ void OpDescBind::RenameInput(const std::string &old_name,
}
struct SetAttrDescVisitor : public boost::static_visitor<void> {
-  explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
-  mutable OpDesc::Attr *attr_;
+  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
+  mutable proto::OpDesc::Attr *attr_;
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
  void operator()(const std::string &v) const { attr_->set_s(v); }
@@ -272,7 +272,9 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
-  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
+  void operator()(proto::BlockDesc *desc) const {
+    attr_->set_block_idx(desc->idx());
+  }
  void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
@@ -297,7 +299,7 @@ void OpDescBind::Flush() {
      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
-          static_cast<framework::AttrType>(attr.second.which() - 1));
+          static_cast<proto::AttrType>(attr.second.which() - 1));
      SetAttrDescVisitor visitor(attr_desc);
      boost::apply_visitor(visitor, attr.second);
    }
@@ -375,7 +377,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
  for (auto &out_pair : this->outputs_) {
    for (auto &out_var_name : out_pair.second) {
      block->FindRecursiveOrCreateVar(out_var_name)
-          ->SetType(VarDesc::LOD_TENSOR);
+          ->SetType(proto::VarDesc::LOD_TENSOR);
    }
  }
}
@@ -484,7 +486,7 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
-VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
+proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  return block_.FindVarRecursive(name)->GetType();
}
......
@@ -33,9 +33,9 @@ class OpDescBind {
  OpDescBind(const std::string &type, const VariableNameMap &inputs,
             const VariableNameMap &outputs, const AttributeMap &attrs);
-  OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
+  OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog);
-  OpDesc *Proto();
+  proto::OpDesc *Proto();
  std::string Type() const { return desc_.type(); }
@@ -59,7 +59,7 @@ class OpDescBind {
    return attrs_.find(name) != attrs_.end();
  }
-  AttrType GetAttrType(const std::string &name) const;
+  proto::AttrType GetAttrType(const std::string &name) const;
  std::vector<std::string> AttrNames() const;
@@ -126,7 +126,7 @@ class OpDescBind {
    return ret_val;
  }
-  OpDesc desc_;
+  proto::OpDesc desc_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
......
@@ -34,7 +34,7 @@ class InferShapeBase {
struct OpInfo {
  OpCreator creator_;
  GradOpMakerFN grad_op_maker_;
-  OpProto* proto_{nullptr};
+  proto::OpProto* proto_{nullptr};
  OpAttrChecker* checker_{nullptr};
  InferVarTypeFN infer_var_type_;
  InferShapeFN infer_shape_;
@@ -43,7 +43,7 @@ struct OpInfo {
    return proto_ != nullptr && checker_ != nullptr;
  }
-  const OpProto& Proto() const {
+  const proto::OpProto& Proto() const {
    PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
    PADDLE_ENFORCE(proto_->IsInitialized(),
                   "Operator Proto must be initialized in op info");
......
@@ -22,6 +22,8 @@ namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
 public:
+  using OpProto = proto::OpProto;
+  using OpAttrChecker = framework::OpAttrChecker;
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}
@@ -80,7 +82,7 @@ class OpProtoAndCheckerMaker {
class NOPMaker : public OpProtoAndCheckerMaker {
 public:
-  NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {}
};
......
@@ -18,7 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
-  TestAttrProtoMaker(paddle::framework::OpProto* proto,
+  TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
                     paddle::framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<float>("scale", "scale of test op");
@@ -27,7 +27,7 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedAttr) {
-  paddle::framework::OpProto op_proto;
+  paddle::framework::proto::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
@@ -35,7 +35,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
-  TestInOutProtoMaker(paddle::framework::OpProto* proto,
+  TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
                      paddle::framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
@@ -44,7 +44,7 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedInOut) {
-  paddle::framework::OpProto op_proto;
+  paddle::framework::proto::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
......
@@ -31,7 +31,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
}
static VariableNameMap ConvertOpDescVarsToVarNameMap(
-    const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
+    const google::protobuf::RepeatedPtrField<proto::OpDesc::Var>&
+        op_desc_vars) {
  VariableNameMap ret_val;
  for (auto& var : op_desc_vars) {
    auto& var_names = ret_val[var.parameter()];
@@ -43,7 +44,8 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
  return ret_val;
}
-std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
+    const proto::OpDesc& op_desc) {
  VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
             "used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
             "instead.";
......
@@ -77,7 +77,7 @@ class OpRegistry {
                                                const VariableNameMap& outputs,
                                                AttributeMap attrs);
-  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
+  static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
  static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
@@ -51,7 +51,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
-                     paddle::framework::OpDesc::Var* var) {
+                     paddle::framework::proto::OpDesc::Var* var) {
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    var->add_arguments(arg_name);
@@ -63,7 +63,7 @@ REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
                             paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());
@@ -71,7 +71,7 @@ TEST(OpRegistry, CreateOp) {
  float scale = 3.3;
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(scale);
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
@@ -83,14 +83,14 @@ TEST(OpRegistry, CreateOp) {
}
TEST(OpRegistry, IllegalAttr) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(-2.0);
  bool caught = false;
@@ -108,7 +108,7 @@ TEST(OpRegistry, IllegalAttr) {
}
TEST(OpRegistry, DefaultValue) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  BuildVar("input", {"aa"}, op_desc.add_inputs());
  BuildVar("output", {"bb"}, op_desc.add_outputs());
@@ -123,7 +123,7 @@ TEST(OpRegistry, DefaultValue) {
}
TEST(OpRegistry, CustomChecker) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  BuildVar("input", {"ii"}, op_desc.add_inputs());
  BuildVar("output", {"oo"}, op_desc.add_outputs());
@@ -145,7 +145,7 @@ TEST(OpRegistry, CustomChecker) {
  // set 'test_attr' set to an illegal value
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
-  attr->set_type(paddle::framework::AttrType::INT);
+  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
@@ -164,7 +164,7 @@ TEST(OpRegistry, CustomChecker) {
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
-  attr->set_type(paddle::framework::AttrType::INT);
+  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(4);
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  paddle::platform::CPUDeviceContext dev_ctx;
......
@@ -377,7 +377,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
    }
  }
-  VarDesc::VarType GetVarType(const std::string& name) const override {
+  proto::VarDesc::VarType GetVarType(const std::string& name) const override {
    auto* var = scope_.FindVar(name);
    return ToVarType(var->Type());
  }
@@ -417,7 +417,7 @@ OpKernelType OperatorWithKernel::GetKernelType(
    const ExecutionContext& ctx) const {
  return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
-DataType OperatorWithKernel::IndicateDataType(
+proto::DataType OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  auto& scope = ctx.scope();
  int data_type = -1;
@@ -443,7 +443,7 @@ DataType OperatorWithKernel::IndicateDataType(
    }
  }
  PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
-  return static_cast<DataType>(data_type);
+  return static_cast<proto::DataType>(data_type);
}
}  // namespace framework
......
@@ -358,12 +358,13 @@ struct OpKernelType {
  };
  platform::Place place_;
-  DataType data_type_;
+  proto::DataType data_type_;
-  OpKernelType(DataType data_type, platform::Place place)
+  OpKernelType(proto::DataType data_type, platform::Place place)
      : place_(place), data_type_(data_type) {}
-  OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
+  OpKernelType(proto::DataType data_type,
+               const platform::DeviceContext& dev_ctx)
      : place_(dev_ctx.GetPlace()), data_type_(data_type) {}
  bool operator==(const OpKernelType& o) const {
@@ -409,7 +410,7 @@ class OperatorWithKernel : public OperatorBase {
 private:
  // indicate kernel DataType by input data. Defaultly all input data must be
  // same.
-  DataType IndicateDataType(const ExecutionContext& ctx) const;
+  proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
......
@@ -58,7 +58,7 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
-                     paddle::framework::OpDesc::Var* var) {
+                     paddle::framework::proto::OpDesc::Var* var) {
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
@@ -70,14 +70,14 @@ REGISTER_OP_WITHOUT_GRADIENT(
    paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, op_desc.add_outputs());
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
  paddle::platform::CPUDeviceContext device_context;
@@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
-    return OpKernelType(DataType::FP32, ctx.GetPlace());
+    return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
  }
};
@@ -195,14 +195,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
-  paddle::framework::OpDesc op_desc;
+  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_with_kernel");
  BuildVar("x", {"IN1"}, op_desc.add_inputs());
  BuildVar("y", {"OUT1"}, op_desc.add_outputs());
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
  paddle::platform::CPUDeviceContext cpu_device_context;
@@ -224,7 +224,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
  using namespace paddle::framework;
-  OpDesc op_desc;
+  proto::OpDesc op_desc;
  op_desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
  BuildVar("k", {"k0"}, op_desc.add_inputs());
@@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) {
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
-  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);
  paddle::platform::CPUDeviceContext cpu_device_context;
......
@@ -26,7 +26,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
  return blocks_.back().get();
}
-ProgramDesc *ProgramDescBind::Proto() {
+proto::ProgramDesc *ProgramDescBind::Proto() {
  for (auto &block : blocks_) {
    block->Flush();
  }
@@ -49,7 +49,7 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
  }
}
-ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
+ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) {
  desc_ = desc;
  for (auto &block_desc : *desc_.mutable_blocks()) {
    blocks_.emplace_back(new BlockDescBind(this, &block_desc));
......
@@ -29,7 +29,7 @@ class ProgramDescBind {
 public:
  ProgramDescBind();
-  explicit ProgramDescBind(const ProgramDesc &desc);
+  explicit ProgramDescBind(const proto::ProgramDesc &desc);
  ProgramDescBind(const ProgramDescBind &o);
@@ -43,10 +43,10 @@ class ProgramDescBind {
  size_t Size() const { return blocks_.size(); }
-  ProgramDesc *Proto();
+  proto::ProgramDesc *Proto();
 private:
-  ProgramDesc desc_;
+  proto::ProgramDesc desc_;
  std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
......
@@ -22,15 +22,15 @@ TEST(ProgramDesc, copy_ctor) {
  ProgramDescBind program;
  auto* global_block = program.MutableBlock(0);
  auto* x = global_block->Var("X");
-  x->SetType(VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  x->SetLoDLevel(0);
-  x->SetDataType(FP32);
+  x->SetDataType(proto::FP32);
  x->SetShape({1000, 784});
  auto* y = global_block->Var("Y");
-  y->SetType(VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  y->SetLoDLevel(0);
-  y->SetDataType(FP32);
+  y->SetDataType(proto::FP32);
  y->SetShape({784, 100});
  auto* op = global_block->AppendOp();
@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
  op->SetInput("Y", {y->Name()});
  auto* out = global_block->Var("Out");
-  out->SetType(VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  op->SetOutput("Y", {out->Name()});
  ProgramDescBind program_copy(program);
@@ -84,15 +84,15 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
  ProgramDescBind program_origin;
  auto* global_block = program_origin.MutableBlock(0);
  auto* x = global_block->Var("X");
-  x->SetType(VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  x->SetLoDLevel(0);
-  x->SetDataType(FP32);
+  x->SetDataType(proto::FP32);
  x->SetShape({1000, 784});
  auto* y = global_block->Var("Y");
-  y->SetType(VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  y->SetLoDLevel(0);
-  y->SetDataType(FP32);
+  y->SetDataType(proto::FP32);
  y->SetShape({784, 100});
  auto* op = global_block->AppendOp();
@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
  op->SetInput("Y", {y->Name()});
  auto* out = global_block->Var("Out");
-  out->SetType(VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  op->SetOutput("Y", {out->Name()});
  std::string binary_str;
......
@@ -29,7 +29,7 @@ const std::string kFetchOpType = "fetch";
 const std::string kDropOutOpType = "dropout";
 const std::string kBatchNormOpType = "batch_norm";
-bool HasDependentVar(const OpDesc& op_desc,
+bool HasDependentVar(const proto::OpDesc& op_desc,
                      const std::set<std::string>& dependent_vars) {
   for (auto& var : op_desc.outputs()) {
     for (auto& argu : var.arguments()) {
@@ -41,14 +41,15 @@ bool HasDependentVar(const OpDesc& op_desc,
   return false;
 }
-bool IsTarget(const OpDesc& op_desc) {
+bool IsTarget(const proto::OpDesc& op_desc) {
   if (op_desc.has_is_target()) {
     return op_desc.is_target();
   }
   return false;
 }
-void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
+void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
+                int block_id) {
   // TODO(tonyyang-svail):
   //    - will change to use multiple blocks for RNN op and Cond Op
@@ -104,12 +105,12 @@ void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
 }
 // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
-void Prune(const ProgramDesc& input, ProgramDesc* output) {
+void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
   prune_impl(input, output, 0);
 }
-void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
-                             int block_id) {
+void inference_optimize_impl(const proto::ProgramDesc& input,
+                             proto::ProgramDesc* output, int block_id) {
   *output = input;
   auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
   for (auto& op_desc : *op_field) {
@@ -125,7 +126,8 @@ void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
   }
 }
-void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
+void InferenceOptimize(const proto::ProgramDesc& input,
+                       proto::ProgramDesc* output) {
   inference_optimize_impl(input, output, 0);
 }
...
@@ -20,9 +20,10 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-void Prune(const ProgramDesc& input, ProgramDesc* output);
-void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
+void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output);
+void InferenceOptimize(const proto::ProgramDesc& input,
+                       proto::ProgramDesc* output);
 }  // namespace framework
 }  // namespace paddle
@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
   for (auto kv : outputs) {
     for (auto v : kv.second) {
       auto var = block->Var(v);
-      var->SetDataType(paddle::framework::DataType::FP32);
+      var->SetDataType(paddle::framework::proto::DataType::FP32);
     }
   }
@@ -57,14 +57,14 @@ TEST(Prune, one_operator) {
   AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
         block);
-  f::ProgramDesc *pdesc = program.Proto();
-  f::ProgramDesc pruned;
-  Prune(*pdesc, &pruned);
+  f::proto::ProgramDesc *pdesc = program.Proto();
+  f::proto::ProgramDesc pruned;
+  f::Prune(*pdesc, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
   pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
-  Prune(*pdesc, &pruned);
+  f::Prune(*pdesc, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
 }
@@ -81,12 +81,12 @@ TEST(Prune, forward) {
   AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, f::AttributeMap{},
         block);
-  f::ProgramDesc *pdesc = program.Proto();
+  f::proto::ProgramDesc *pdesc = program.Proto();
   for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
-    f::ProgramDesc pruned;
+    f::proto::ProgramDesc pruned;
     pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
-    Prune(*pdesc, &pruned);
+    f::Prune(*pdesc, &pruned);
     PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
   }
 }
@@ -104,11 +104,11 @@ TEST(Prune, multi_input_op) {
   AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}},
         f::AttributeMap{}, block);
-  f::ProgramDesc *pdesc = program.Proto();
+  f::proto::ProgramDesc *pdesc = program.Proto();
   pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
-  f::ProgramDesc pruned;
-  Prune(*pdesc, &pruned);
+  f::proto::ProgramDesc pruned;
+  f::Prune(*pdesc, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
 }
@@ -123,11 +123,11 @@ TEST(Prune, multi_output_op) {
   AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
         block);
-  f::ProgramDesc *pdesc = program.Proto();
+  f::proto::ProgramDesc *pdesc = program.Proto();
   pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
-  f::ProgramDesc pruned;
-  Prune(*pdesc, &pruned);
+  f::proto::ProgramDesc pruned;
+  f::Prune(*pdesc, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
 }
@@ -142,11 +142,11 @@ TEST(Prune, multi_target) {
   AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
         block);
-  f::ProgramDesc *pdesc = program.Proto();
+  f::proto::ProgramDesc *pdesc = program.Proto();
   pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
   pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
-  f::ProgramDesc pruned;
-  Prune(*pdesc, &pruned);
+  f::proto::ProgramDesc pruned;
+  f::Prune(*pdesc, &pruned);
   PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
 }
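The tests above encode the pruning contract: an op survives only if it is marked `is_target` or a target transitively depends on one of its outputs. A compact sketch of driving `Prune` by hand, under the same `namespace f = paddle::framework` alias the tests use:

```cpp
namespace f = paddle::framework;

// Keep only the sub-graph that the chosen target op depends on.
f::proto::ProgramDesc PruneToOp(f::ProgramDescBind* program, int op_index) {
  f::proto::ProgramDesc* pdesc = program->Proto();
  pdesc->mutable_blocks(0)->mutable_ops(op_index)->set_is_target(true);

  f::proto::ProgramDesc pruned;
  f::Prune(*pdesc, &pruned);  // ops nothing depends on are dropped
  return pruned;
}
```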
@@ -57,17 +57,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
     SetDim(names[i], dims[i]);
   }
 }
-std::vector<VarDesc::VarType> InferShapeContext::GetInputsVarType(
+std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
     const std::string &name) const {
   return GetVarTypes(Inputs(name));
 }
-std::vector<VarDesc::VarType> InferShapeContext::GetOutputsVarType(
+std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
     const std::string &name) const {
   return GetVarTypes(Outputs(name));
 }
-std::vector<VarDesc::VarType> InferShapeContext::GetVarTypes(
+std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
     const std::vector<std::string> &names) const {
-  std::vector<VarDesc::VarType> retv;
+  std::vector<proto::VarDesc::VarType> retv;
   retv.resize(names.size());
   std::transform(names.begin(), names.end(), retv.begin(),
                  std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
...
@@ -27,8 +27,9 @@ class InferShapeContext {
   virtual bool HasInput(const std::string &name) const = 0;
   virtual bool HasOutput(const std::string &name) const = 0;
-  std::vector<VarDesc::VarType> GetInputsVarType(const std::string &name) const;
-  std::vector<VarDesc::VarType> GetOutputsVarType(
+  std::vector<proto::VarDesc::VarType> GetInputsVarType(
+      const std::string &name) const;
+  std::vector<proto::VarDesc::VarType> GetOutputsVarType(
       const std::string &name) const;
   virtual bool HasInputs(const std::string &name) const = 0;
@@ -65,10 +66,10 @@ class InferShapeContext {
   std::vector<framework::DDim> GetDims(
       const std::vector<std::string> &names) const;
-  std::vector<VarDesc::VarType> GetVarTypes(
+  std::vector<proto::VarDesc::VarType> GetVarTypes(
       const std::vector<std::string> &names) const;
-  virtual VarDesc::VarType GetVarType(const std::string &name) const = 0;
+  virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
 };
 }  // namespace framework
...
@@ -18,15 +18,17 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
+proto::VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
-void VarDescBind::SetType(VarDesc::VarType type) { desc_.set_type(type); }
+void VarDescBind::SetType(proto::VarDesc::VarType type) {
+  desc_.set_type(type);
+}
 void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }
-void VarDescBind::SetDataType(DataType data_type) {
+void VarDescBind::SetDataType(proto::DataType data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
 }
@@ -34,14 +36,16 @@ std::vector<int64_t> VarDescBind::Shape() const {
   return RepeatedToVector(tensor_desc().dims());
 }
-DataType VarDescBind::GetDataType() const { return tensor_desc().data_type(); }
+proto::DataType VarDescBind::GetDataType() const {
+  return tensor_desc().data_type();
+}
 void VarDescBind::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type()) {
-    case VarDesc::LOD_TENSOR:
+    case proto::VarDesc::LOD_TENSOR:
       desc_.mutable_lod_tensor()->set_lod_level(lod_level);
       break;
-    case VarDesc::LOD_TENSOR_ARRAY:
+    case proto::VarDesc::LOD_TENSOR_ARRAY:
       desc_.mutable_tensor_array()->set_lod_level(lod_level);
       break;
     default:
@@ -52,9 +56,9 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
 int32_t VarDescBind::GetLodLevel() const {
   switch (desc_.type()) {
-    case VarDesc::LOD_TENSOR:
+    case proto::VarDesc::LOD_TENSOR:
       return desc_.lod_tensor().lod_level();
-    case VarDesc::LOD_TENSOR_ARRAY:
+    case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().lod_level();
     default:
       PADDLE_THROW("Tensor type=%d does not support LoDLevel",
@@ -62,29 +66,29 @@ int32_t VarDescBind::GetLodLevel() const {
   }
 }
-const TensorDesc &VarDescBind::tensor_desc() const {
+const proto::TensorDesc &VarDescBind::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
   switch (desc_.type()) {
-    case VarDesc::SELECTED_ROWS:
+    case proto::VarDesc::SELECTED_ROWS:
       return desc_.selected_rows();
-    case VarDesc::LOD_TENSOR:
+    case proto::VarDesc::LOD_TENSOR:
      return desc_.lod_tensor().tensor();
-    case VarDesc::LOD_TENSOR_ARRAY:
+    case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().tensor();
     default:
       PADDLE_THROW("Unexpected branch.");
   }
 }
-TensorDesc *VarDescBind::mutable_tensor_desc() {
+proto::TensorDesc *VarDescBind::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(),
                  "invoke MutableTensorDesc must after set type");
   switch (desc_.type()) {
-    case VarDesc::SELECTED_ROWS:
+    case proto::VarDesc::SELECTED_ROWS:
       return desc_.mutable_selected_rows();
-    case VarDesc::LOD_TENSOR:
+    case proto::VarDesc::LOD_TENSOR:
       return desc_.mutable_lod_tensor()->mutable_tensor();
-    case VarDesc::LOD_TENSOR_ARRAY:
+    case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.mutable_tensor_array()->mutable_tensor();
     default:
       PADDLE_THROW("Unexpected branch.");
...
@@ -57,40 +57,40 @@ class VarDescBind {
  public:
   explicit VarDescBind(const std::string &name) {
     desc_.set_name(name);
-    desc_.set_type(VarDesc::LOD_TENSOR);
+    desc_.set_type(proto::VarDesc::LOD_TENSOR);
   }
-  explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
+  explicit VarDescBind(const proto::VarDesc &desc) : desc_(desc) {}
-  VarDesc *Proto() { return &desc_; }
+  proto::VarDesc *Proto() { return &desc_; }
   std::string Name() const { return desc_.name(); }
   void SetShape(const std::vector<int64_t> &dims);
-  void SetDataType(DataType data_type);
+  void SetDataType(proto::DataType data_type);
   std::vector<int64_t> Shape() const;
-  DataType GetDataType() const;
+  proto::DataType GetDataType() const;
   void SetLoDLevel(int32_t lod_level);
   int32_t GetLodLevel() const;
-  VarDesc::VarType GetType() const;
-  void SetType(VarDesc::VarType type);
+  proto::VarDesc::VarType GetType() const;
+  void SetType(proto::VarDesc::VarType type);
   bool Persistable() const { return desc_.persistable(); }
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
  private:
-  const TensorDesc &tensor_desc() const;
-  TensorDesc *mutable_tensor_desc();
+  const proto::TensorDesc &tensor_desc() const;
+  proto::TensorDesc *mutable_tensor_desc();
-  VarDesc desc_;
+  proto::VarDesc desc_;
 };
 }  // namespace framework
 }  // namespace paddle
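`VarDescBind` keeps the proto message as its single source of truth, and every typed accessor now speaks `proto::` enums. A minimal sketch of describing a float32 LoD tensor with the relocated types; the variable name and shape here are illustrative:

```cpp
namespace f = paddle::framework;

f::VarDescBind MakeHiddenVar() {
  // All setters write straight into the underlying proto::VarDesc message.
  f::VarDescBind var("hidden");                // constructor defaults to LOD_TENSOR
  var.SetType(f::proto::VarDesc::LOD_TENSOR);  // type must be set before tensor_desc access
  var.SetDataType(f::proto::DataType::FP32);
  var.SetShape({64, 128});
  var.SetLoDLevel(1);  // only LOD_TENSOR / LOD_TENSOR_ARRAY carry a LoD level
  return var;
}
```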
@@ -20,15 +20,15 @@
 namespace paddle {
 namespace framework {
-inline VarDesc::VarType ToVarType(std::type_index type) {
+inline proto::VarDesc::VarType ToVarType(std::type_index type) {
   if (type.hash_code() == typeid(LoDTensor).hash_code()) {
-    return VarDesc_VarType_LOD_TENSOR;
+    return proto::VarDesc_VarType_LOD_TENSOR;
   } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
-    return VarDesc_VarType_LOD_RANK_TABLE;
+    return proto::VarDesc_VarType_LOD_RANK_TABLE;
   } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
-    return VarDesc_VarType_LOD_TENSOR_ARRAY;
+    return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
   } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
-    return VarDesc_VarType_SELECTED_ROWS;
+    return proto::VarDesc_VarType_SELECTED_ROWS;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -37,16 +37,16 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
 template <typename Visitor>
 inline void VisitVarType(const Variable& var, Visitor visitor) {
   switch (ToVarType(var.Type())) {
-    case VarDesc_VarType_LOD_TENSOR:
+    case proto::VarDesc_VarType_LOD_TENSOR:
       visitor(var.Get<framework::LoDTensor>());
       return;
-    case VarDesc_VarType_LOD_RANK_TABLE:
+    case proto::VarDesc_VarType_LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case VarDesc_VarType_LOD_TENSOR_ARRAY:
+    case proto::VarDesc_VarType_LOD_TENSOR_ARRAY:
      visitor(var.Get<LoDTensorArray>());
       return;
-    case VarDesc_VarType_SELECTED_ROWS:
+    case proto::VarDesc_VarType_SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
     default:
...
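`VisitVarType` dispatches on the runtime variable type and calls the matching overload, so a visitor just needs one `operator()` per payload kind handled above. A minimal sketch of such a visitor; the `SizeVisitor` name and the `Inspect` wrapper are hypothetical:

```cpp
namespace f = paddle::framework;

// One operator() per variable kind that VisitVarType can dispatch to.
struct SizeVisitor {
  void operator()(const f::LoDTensor& t) { /* e.g. inspect t.dims() */ }
  void operator()(const f::LoDRankTable& t) { /* ... */ }
  void operator()(const f::LoDTensorArray& t) { /* e.g. inspect t.size() */ }
  void operator()(const f::SelectedRows& t) { /* ... */ }
};

void Inspect(const f::Variable& var) {
  f::VisitVarType(var, SizeVisitor{});  // unsupported types fall into the default branch
}
```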
@@ -36,14 +36,14 @@ class SumOpVarTypeInference : public VarTypeInference {
   void operator()(const OpDescBind &op_desc,
                   BlockDescBind *block) const override {
     auto &inputs = op_desc.Input("X");
-    auto default_var_type = VarDesc::SELECTED_ROWS;
+    auto default_var_type = proto::VarDesc::SELECTED_ROWS;
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == VarDesc::LOD_TENSOR;
+          return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
-      default_var_type = VarDesc::LOD_TENSOR;
+      default_var_type = proto::VarDesc::LOD_TENSOR;
     }
     auto out_var_name = op_desc.Output("Out").front();
@@ -68,19 +68,19 @@ TEST(InferVarType, sum_op) {
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
-  prog.MutableBlock(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test_out");
   op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(VarDesc::SELECTED_ROWS,
+  ASSERT_EQ(proto::VarDesc::SELECTED_ROWS,
             prog.MutableBlock(0)->Var("test_out")->GetType());
-  prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR);
   op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(VarDesc::LOD_TENSOR,
+  ASSERT_EQ(proto::VarDesc::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test_out")->GetType());
 }
@@ -91,14 +91,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});
-  prog.MutableBlock(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test2_out");
   op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
+  ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
...
@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
 class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AccuracyOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     // TODO(typhoonzero): support both inference value and indices.
     AddInput("Out", "The network output of topk (inferences)");
...
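From here on, every `OpMaker` constructor drops the `framework::` qualifier on `OpProto` and `OpAttrChecker`. That presumes those names are already visible inside `paddle::operators`, most likely via using-declarations introduced elsewhere in this commit; they are not shown in these hunks. A minimal sketch of the resulting maker pattern, with a hypothetical `my_op`:

```cpp
namespace paddle {
namespace operators {

// Assumed to be provided namespace-wide by this commit (not shown in the hunks).
using framework::OpProto;
using framework::OpAttrChecker;

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of my_op");
    AddOutput("Y", "Output of my_op");
    AddAttr<float>("scale", "Scaling factor").SetDefault(1.0f);
    AddComment(R"DOC(A hypothetical operator, here only to illustrate the maker API.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle
```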
@@ -38,9 +38,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
 class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SigmoidOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sigmoid operator");
     AddOutput("Y", "Output of Sigmoid operator");
     AddComment(R"DOC(
@@ -54,9 +53,8 @@ $$y = \frac{1}{1 + e^{-x}}$$
 class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LogSigmoidOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of LogSigmoid operator");
     AddOutput("Y", "Output of LogSigmoid operator");
     AddComment(R"DOC(
@@ -70,8 +68,8 @@ $$y = \log \frac{1}{1 + e^{-x}}$$
 class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Exp operator");
     AddOutput("Y", "Output of Exp operator");
     AddComment(R"DOC(
@@ -85,8 +83,8 @@ $y = e^x$
 class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu operator");
     AddOutput("Y", "Output of Relu operator");
     AddComment(R"DOC(
@@ -100,9 +98,8 @@ $y = \max(x, 0)$
 class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LeakyReluOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of LeakyRelu operator");
     AddOutput("Y", "Output of LeakyRelu operator");
     AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
@@ -117,9 +114,8 @@ $y = \max(x, \alpha * x)$
 class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftShrinkOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softshrink operator");
     AddOutput("Y", "Output of Softshrink operator");
     AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
@@ -140,8 +136,8 @@ $$
 class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Tanh operator");
     AddOutput("Y", "Output of Tanh operator");
     AddComment(R"DOC(
@@ -155,9 +151,8 @@ $$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
 class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TanhShrinkOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of TanhShrink operator");
     AddOutput("Y", "Output of TanhShrink operator");
     AddComment(R"DOC(
@@ -171,9 +166,8 @@ $$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
 class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  HardShrinkOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardShrink operator");
     AddOutput("Y", "Output of HardShrink operator");
     AddAttr<float>("threshold", "The value of threshold for HardShrink")
@@ -195,8 +189,8 @@ $$
 class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sqrt operator");
     AddOutput("Y", "Output of Sqrt operator");
     AddComment(R"DOC(
@@ -210,8 +204,8 @@ $y = \sqrt{x}$
 class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Abs operator");
     AddOutput("Y", "Output of Abs operator");
     AddComment(R"DOC(
@@ -225,8 +219,8 @@ $y = |x|$
 class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CeilOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Ceil operator");
     AddOutput("Y", "Output of Ceil operator");
     AddComment(R"DOC(
@@ -240,8 +234,8 @@ $y = ceil(x)$
 class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FloorOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Floor operator");
     AddOutput("Y", "Output of Floor operator");
     AddComment(R"DOC(
@@ -255,8 +249,8 @@ $y = floor(x)$
 class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RoundOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Round operator");
     AddOutput("Y", "Output of Round operator");
     AddComment(R"DOC(
@@ -270,9 +264,8 @@ $y = [x]$
 class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ReciprocalOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Reciprocal operator");
     AddOutput("Y", "Output of Reciprocal operator");
     AddComment(R"DOC(
@@ -286,8 +279,8 @@ $$y = \frac{1}{x}$$
 class LogOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Log operator");
     AddOutput("Y", "Output of Log operator");
     AddComment(R"DOC(
@@ -303,8 +296,8 @@ Natural logarithm of x.
 class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Square operator");
     AddOutput("Y", "Output of Square operator");
     AddComment(R"DOC(
@@ -318,9 +311,8 @@ $y = x^2$
 class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftplusOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softplus operator");
     AddOutput("Y", "Output of Softplus operator");
     AddComment(R"DOC(
@@ -334,9 +326,8 @@ $y = \ln(1 + e^{x})$
 class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftsignOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softsign operator");
     AddOutput("Y", "Output of Softsign operator");
     AddComment(R"DOC(
@@ -350,8 +341,8 @@ $$y = \frac{x}{1 + |x|}$$
 class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of BRelu operator");
     AddOutput("Y", "Output of BRelu operator");
     AddAttr<float>("t_min", "The min marginal value of BRelu")
@@ -369,9 +360,8 @@ $y = \max(\min(x, t_{min}), t_{max})$
 class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftReluOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of SoftRelu operator");
     AddOutput("Y", "Output of SoftRelu operator");
     AddAttr<float>("threshold", "The threshold value of SoftRelu")
@@ -387,8 +377,8 @@ $y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
 class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ELU operator");
     AddOutput("Y", "Output of ELU operator");
     AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
@@ -406,8 +396,8 @@ $y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
 class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu6 operator");
     AddOutput("Y", "Output of Relu6 operator");
     AddAttr<float>("threshold", "The threshold value of Relu6")
@@ -423,8 +413,8 @@ $y = \min(\max(0, x), 6)$
 class PowOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Pow operator");
     AddOutput("Y", "Output of Pow operator");
     AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
@@ -439,8 +429,8 @@ $y = x^{factor}$
 class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of STanh operator");
     AddOutput("Y", "Output of STanh operator");
     AddAttr<float>("scale_a", "The scale parameter of a for the input")
@@ -458,9 +448,8 @@ $$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
 class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ThresholdedReluOpMaker(framework::OpProto *proto,
-                         framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ThresholdedRelu operator");
     AddOutput("Y", "Output of ThresholdedRelu operator");
     AddAttr<float>("threshold", "The threshold location of activation")
@@ -481,9 +470,8 @@ $$
 class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  HardSigmoidOpMaker(framework::OpProto *proto,
-                     framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardSigmoid operator");
     AddOutput("Y", "Output of HardSigmoid operator");
     AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
@@ -508,8 +496,8 @@ It is recommended to use the defaults for this activation.
 class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SwishOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
+  SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Swish operator");
     AddOutput("Y", "Output of Swish operator");
     AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
...
@@ -59,8 +59,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
 class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AdadeltaOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
...
@@ -59,8 +59,7 @@ class AdagradOp : public framework::OperatorWithKernel {
 class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AdagradOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
...
@@ -73,7 +73,7 @@ class AdamOp : public framework::OperatorWithKernel {
 class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
...
@@ -67,7 +67,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
 class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AdamaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
...
@@ -114,8 +114,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
 class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ArrayToLoDTensorOpProtoMaker(framework::OpProto *proto,
-                               framework::OpAttrChecker *op_checker)
+  ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(std::vector<LodTensor>) A vector of tensors that is going to "
...
@@ -86,8 +86,7 @@ class AssignOp : public framework::OperatorBase {
 class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AssignOpProtoMaker(framework::OpProto *proto,
-                     framework::OpAttrChecker *op_checker)
+  AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
@@ -109,8 +108,8 @@ class AssignInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     if (context->HasInput("X")) {
       auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::VarDesc_VarType_SELECTED_ROWS ||
-          type == framework::VarDesc_VarType_LOD_TENSOR) {
+      if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS ||
+          type == framework::proto::VarDesc_VarType_LOD_TENSOR) {
         context->SetOutputDim("Out", context->GetInputDim("X"));
       }
     }
...
@@ -49,7 +49,7 @@ class AucOp : public framework::OperatorWithKernel {
 class AucOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  AucOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Out",
              "A floating point 2D tensor, values are in the range [0, 1]."
...
...@@ -85,8 +85,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -85,8 +85,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchNormOpMaker(framework::OpProto *proto, BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>("is_test", "").SetDefault(false); AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9); AddAttr<float>("momentum", "").SetDefault(0.9);
......
...@@ -83,9 +83,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -83,9 +83,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker { class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchDecodeOpProtoMaker(framework::OpProto* proto, BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", AddInput("Ids",
"(LodTensorArray)" "(LodTensorArray)"
"score of the candidate words in each step"); "score of the candidate words in each step");
...@@ -123,10 +122,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference { ...@@ -123,10 +122,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDescBind& op_desc, void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override { framework::BlockDescBind* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) { for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
} }
for (auto& o : op_desc.Output("SentenceScores")) { for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
} }
} }
}; };
......
...@@ -153,8 +153,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) { ...@@ -153,8 +153,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
class BeamSearchProtoAndCheckerMaker class BeamSearchProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker { : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchProtoAndCheckerMaker(framework::OpProto *proto, BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
// inputs and outputs stored in proto // inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step"); AddInput("pre_ids", "ids in previous step");
......
...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BilinearTensorProductOpMaker(framework::OpProto* proto, BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of bilinear_tensor_product operator."); AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator."); AddInput("Y", "The second input of bilinear_tensor_product operator.");
......
@@ -20,8 +20,7 @@ namespace operators {
 class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CastOpProtoMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of cast op");
     AddOutput("Out", "The output tensor of cast op");
...
@@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
     auto* in = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     framework::VisitDataType(
-        static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
+        static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
         CastOpFunctor<DeviceContext, InT>(
             in, out, context.template device_context<DeviceContext>()));
   }
...
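The cast kernel above reads `out_dtype` as a plain `int` attribute and casts it back to the enum, now `framework::proto::DataType`. Below is a self-contained mimic of that round trip, not PaddlePaddle code: `FP32 = 5` matches the "(int, default 5 (FP32))" attribute comments elsewhere in this diff, while the other enumerator value is a placeholder.

```cpp
// Mimic of the int-attribute-to-enum round trip used by cast_op and by the
// GetKernelType overrides in this commit. Placeholder enum, not the real
// generated framework::proto::DataType.
#include <cassert>

namespace framework {
namespace proto {
enum DataType { FP32 = 5, FP64 = 6 };  // FP64's value is an assumption
}  // namespace proto
}  // namespace framework

int main() {
  // Attributes travel through the op description as plain ints.
  int out_dtype_attr = 5;
  auto dtype = static_cast<framework::proto::DataType>(out_dtype_attr);
  assert(dtype == framework::proto::FP32);  // enum recovered for dispatch
  return 0;
}
```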
@@ -57,15 +57,14 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(framework::DataType::FP32,
+    return framework::OpKernelType(framework::proto::DataType::FP32,
                                    ctx.device_context());
   }
 };
 
 class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ChunkEvalOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Inference",
              "(Tensor, default: Tensor<int64_t>). "
...
@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
 class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ClipByNormOpMaker(framework::OpProto* proto,
-                    framework::OpAttrChecker* op_checker)
+  ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor) The input of clip_by_norm op."
...
@@ -38,7 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor)The input of clip op."
...
@@ -20,8 +20,7 @@ namespace operators {
 template <typename OpComment>
 class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CompareOpProtoMaker(framework::OpProto *proto,
-                      framework::OpAttrChecker *op_checker)
+  CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     OpComment comment;
     AddInput("X",
...
@@ -58,7 +58,7 @@ class ConcatOp : public framework::OperatorWithKernel {
 class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input tensors of concat operator.").AsDuplicable();
     AddOutput("Out", "Output tensor of concat operator.");
...
@@ -205,8 +205,7 @@ void CondOp::Run(const Scope& scope,
 class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CondOpProtoAndCheckerMaker(framework::OpProto* proto,
-                             framework::OpAttrChecker* op_checker)
+  CondOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Cond", "The condition, which is a bool vector");
     AddInput("Xs", "Inputs of Subnets").AsDuplicable();
...
@@ -74,8 +74,7 @@ class ConditionalBlockOp : public ConditionalOp {
 class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ConditionalBlockOpProtoMaker(framework::OpProto *proto,
-                               framework::OpAttrChecker *op_checker)
+  ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The conditional variable of this operator. If X is empty, the "
...
@@ -19,8 +19,7 @@ namespace operators {
 class CudnnConv2DOpMaker : public Conv2DOpMaker {
  public:
-  CudnnConv2DOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  CudnnConv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv2DOpMaker(proto, op_checker) {
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
@@ -34,8 +33,7 @@ class CudnnConv2DOpMaker : public Conv2DOpMaker {
 class CudnnConv3DOpMaker : public Conv3DOpMaker {
  public:
-  CudnnConv3DOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  CudnnConv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv3DOpMaker(proto, op_checker) {
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
...
@@ -66,8 +66,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
 
-Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
-                             framework::OpAttrChecker* op_checker)
+Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput(
       "Input",
@@ -138,8 +137,7 @@ $$
 )DOC");
 }
 
-Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
-                             framework::OpAttrChecker* op_checker)
+Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput(
       "Input",
...
@@ -50,14 +50,12 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
 // operator implementations can reuse the code.
 class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Conv2DOpMaker(framework::OpProto* proto,
-                framework::OpAttrChecker* op_checker);
+  Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 
 class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Conv3DOpMaker(framework::OpProto* proto,
-                framework::OpAttrChecker* op_checker);
+  Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 
 class ConvOp : public framework::OperatorWithKernel {
...
@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
 class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ConvShiftOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
...
@@ -19,8 +19,7 @@ namespace operators {
 class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
  public:
-  CudnnConv2DTransposeOpMaker(framework::OpProto* proto,
-                              framework::OpAttrChecker* op_checker)
+  CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv2DTransposeOpMaker(proto, op_checker) {
     AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
         .SetDefault({1, 1});
@@ -36,8 +35,7 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
 class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
  public:
-  CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
-                              framework::OpAttrChecker* op_checker)
+  CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv3DTransposeOpMaker(proto, op_checker) {
     AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
         .SetDefault({1, 1, 1});
...
@@ -53,8 +53,8 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
 
-Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
-    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
+                                               OpAttrChecker* op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput(
       "Input",
@@ -112,8 +112,8 @@ Example:
 )DOC");
 }
 
-Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
-    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
+                                               OpAttrChecker* op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput("Input",
            "(Tensor) The input tensor of convolution transpose operator."
...
@@ -30,14 +30,12 @@ using DDim = framework::DDim;
 // operator implementations can reuse the code.
 class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Conv2DTransposeOpMaker(framework::OpProto* proto,
-                         framework::OpAttrChecker* op_checker);
+  Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 
 class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Conv3DTransposeOpMaker(framework::OpProto* proto,
-                         framework::OpAttrChecker* op_checker);
+  Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 
 class ConvTransposeOp : public framework::OperatorWithKernel {
...
@@ -62,7 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
 class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The 1st input of cos_sim op.");
     AddInput("Y", "The 2nd input of cos_sim op.");
...
@@ -18,8 +18,7 @@ namespace paddle {
 namespace operators {
 class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CRFDecodingOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Emission",
              "(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
...
@@ -52,7 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
 class CropOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  CropOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input of pad op. "
...
@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
 class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  CrossEntropyOpMaker(framework::OpProto* proto,
-                      framework::OpAttrChecker* op_checker)
+  CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
...
@@ -55,8 +55,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
 class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  DecayedAdagradOpMaker(framework::OpProto *proto,
-                        framework::OpAttrChecker *op_checker)
+  DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
...
@@ -40,8 +40,7 @@ class DropoutOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  DropoutOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of dropout op.");
     AddOutput("Out", "The output of dropout op.");
...
@@ -19,8 +19,7 @@ namespace paddle {
 namespace operators {
 class ElementwiseAddOpMaker : public ElementwiseOpMaker {
  public:
-  ElementwiseAddOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : ElementwiseOpMaker(proto, op_checker) {
     SetComment("Add", "$Out = X + Y$");
     AddComment(comment_);
...
@@ -19,8 +19,7 @@ namespace paddle {
 namespace operators {
 class ElementwiseDivOpMaker : public ElementwiseOpMaker {
  public:
-  ElementwiseDivOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : ElementwiseOpMaker(proto, op_checker) {
     SetComment("Div", "$Out = X / Y$");
     AddComment(comment_);
...
@@ -20,8 +20,7 @@ namespace operators {
 class ElementwiseMulOpMaker : public ElementwiseOpMaker {
  public:
-  ElementwiseMulOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : ElementwiseOpMaker(proto, op_checker) {
     SetComment("Mul", "$Out = X \\odot\\ Y$");
     AddComment(comment_);
...
@@ -43,8 +43,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
 class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ElementwiseOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) The first input tensor of elementwise op");
     AddInput("Y", "(Tensor) The second input tensor of elementwise op");
...
@@ -19,8 +19,7 @@ namespace paddle {
 namespace operators {
 class ElementwiseSubOpMaker : public ElementwiseOpMaker {
  public:
-  ElementwiseSubOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : ElementwiseOpMaker(proto, op_checker) {
     SetComment("Sub", "$Out = X - Y$");
     AddComment(comment_);
...
@@ -55,7 +55,7 @@ class ExpandOp : public framework::OperatorWithKernel {
 class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
...
@@ -54,8 +54,7 @@ class FeedOp : public framework::OperatorBase {
 class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FeedOpInfoMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of feed op");
     AddOutput("Out", "The output of feed op");
...
@@ -61,8 +61,7 @@ class FetchOp : public framework::OperatorBase {
 class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FetchOpInfoMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of fetch op");
     AddOutput("Out", "The output of fetch op");
...
@@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
+        static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
         ctx.device_context());
   }
 };
@@ -60,13 +60,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
 class FillConstantBatchSizeLikeOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
-                                   framework::OpAttrChecker *op_checker)
+  FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddInput("Input",
              "(Tensor) Tensor "
              "whose dim_idx th dimension is used to specify the batch_size");
...
@@ -34,7 +34,8 @@ class FillConstantOp : public framework::OperatorBase {
   using framework::OperatorBase::OperatorBase;
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
-    auto data_type = static_cast<framework::DataType>(Attr<int>("dtype"));
+    auto data_type =
+        static_cast<framework::proto::DataType>(Attr<int>("dtype"));
     auto value = Attr<float>("value");
     auto force_cpu = Attr<bool>("force_cpu");
     auto &out =
@@ -52,13 +53,12 @@ class FillConstantOp : public framework::OperatorBase {
 class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FillConstantOpMaker(framework::OpProto *proto,
-                      framework::OpAttrChecker *op_checker)
+  FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
...
@@ -48,7 +48,7 @@ class FillOp : public framework::OperatorBase {
                        "Cannot find variable %s", Output("Out"))
                        .GetMutable<framework::LoDTensor>());
     out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
-    auto dtype = static_cast<framework::DataType>(Attr<int>("dtype"));
+    auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
     platform::CPUPlace cpu;
     auto force_cpu = Attr<bool>("force_cpu");
     out.mutable_data(force_cpu ? cpu : dev_ctx.GetPlace(),
@@ -76,7 +76,7 @@ class FillOp : public framework::OperatorBase {
 class FillOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FillOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  FillOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddComment(R"DOC(Fill operator
@@ -88,7 +88,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
         "value", "The float values of tensor, which are flatten in row major");
     AddAttr<std::vector<int>>("shape", "The shape of output tensor");
     AddAttr<int>("dtype", "The data type of output tensor, Default is float")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddAttr<bool>("force_cpu",
                   "Whether the output tensor must be at CPU memory or not. "
                   "Default is false.")
...
@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
 class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FillZerosLikeOpMaker(framework::OpProto *proto,
-                       framework::OpAttrChecker *op_checker)
+  FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of fill-zeros-like op.");
     AddOutput("Y", "The variable will be filled up with zeros.");
...
@@ -57,7 +57,7 @@ class FTRLOp : public framework::OperatorWithKernel {
 class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param",
              "(Tensor, default Tensor<float>) "
...
@@ -67,7 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
 class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The source input of gather op");
     AddInput("Index", "The index input of gather op");
...
@@ -60,15 +60,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
+        static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
         ctx.device_context());
   }
 };
 
 class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GaussianRandomOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "Output matrix of gaussian random op");
@@ -91,7 +90,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("dtype",
                  "(int, default 5(FP32)) "
                  "Output data type.")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddComment(R"DOC(
 GaussianRandom Operator.
...
@@ -67,7 +67,7 @@ class GRUOp : public framework::OperatorWithKernel {
 class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  GRUOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input",
              "(LoDTensor) The first input is a LodTensor, which supports "
...
@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
 class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GRUUnitOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  GRUUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input",
              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
...
@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  HingeLossOpMaker(framework::OpProto* proto,
-                   framework::OpAttrChecker* op_checker)
+  HingeLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Logits",
              "The input value (Logits) of Hinge loss op."
...
@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  HuberLossOpMaker(framework::OpProto* proto,
-                   framework::OpAttrChecker* op_checker)
+  HuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input value of huber loss op."
...
@@ -70,8 +70,7 @@ class IncrementOp : public framework::OperatorBase {
 class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  IncrementOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) The input tensor of increment operator");
     AddOutput("Out", "(Tensor) The output tensor of increment operator.");
...
@@ -47,8 +47,7 @@ class IsEmptyOp : public framework::OperatorBase {
 class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  IsEmptyOpProtoMaker(framework::OpProto *proto,
-                      framework::OpAttrChecker *op_checker)
+  IsEmptyOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(kInput, "(Tensor) Tensor which is to be checked.");
     AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
...
@@ -48,7 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel {
 class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  L1NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  L1NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) The input of l1_norm op.");
     AddOutput("Out", "(Scalar) The output of l1_norm op.");
...
@@ -19,8 +19,7 @@ namespace operators {
 class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LinearChainCRFOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  LinearChainCRFOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Emission",
              "(LoDTensor, default LoDTensor<float>) "
...
@@ -58,8 +58,7 @@ class LoadOp : public framework::OperatorBase {
 class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LoadOpProtoMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "(Tensor) The tensor need to be loaded");
     AddAttr<std::string>("file_path",
...
@@ -38,8 +38,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
 class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LoDArrayLengthProtoMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  LoDArrayLengthProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensorArray) The input tensor array.");
     AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
...
@@ -30,13 +30,13 @@ class LoDRankTableOp : public framework::OperatorBase {
         scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
     VLOG(10) << "Level = " << static_cast<size_t>(Attr<int>("level"));
     out->Reset(x.lod(), static_cast<size_t>(Attr<int>("level")));
+    VLOG(10) << Input("X") << "'s lod information is " << *out;
   }
 };
 
 class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LoDRankTableOpProtoMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  LoDRankTableOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LoDTensor) input lod tensor, must contain lod information.");
@@ -67,7 +67,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
                   framework::BlockDescBind *block) const override {
     for (auto &o : op_desc.Output("Out")) {
       block->FindRecursiveOrCreateVar(o)->SetType(
-          framework::VarDesc::LOD_RANK_TABLE);
+          framework::proto::VarDesc::LOD_RANK_TABLE);
     }
   }
 };
...
@@ -48,8 +48,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
 class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LoDResetOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
     AddInput("TargetLoD",
...
@@ -97,8 +97,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
 class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LoDTensorToArrayOpProtoMaker(framework::OpProto *proto,
-                               framework::OpAttrChecker *op_checker)
+  LoDTensorToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "");
     AddInput("RankTable", "");
@@ -131,7 +130,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDescBind &op_desc,
                   framework::BlockDescBind *block) const override {
     for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
+      block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
     }
   }
 };
...
@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LogLossOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  LogLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Predicted",
              "The input value (Predicted) of Log loss op."
...
@@ -20,8 +20,7 @@ namespace operators {
 template <typename OpComment>
 class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  BinaryLogicalOpProtoMaker(framework::OpProto *proto,
-                            framework::OpAttrChecker *op_checker)
+  BinaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     OpComment comment;
     AddInput("X",
@@ -45,8 +44,7 @@ Each element of Out is calculated by %s
 template <typename OpComment>
 class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  UnaryLogicalOpProtoMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  UnaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     OpComment comment;
     AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
...
@@ -51,8 +51,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
 class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LookupTableOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("W",
              "An input represents embedding tensors, "
@@ -117,11 +116,12 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      block->Var(out_var_name)->SetType(framework::VarDesc::SELECTED_ROWS);
+      block->Var(out_var_name)
+          ->SetType(framework::proto::VarDesc::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::VarDesc::LOD_TENSOR);
+      block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR);
     }
   }
 };
...
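The var-type inference above is one spot in this commit where the new `proto::VarDesc` tags carry real logic: a sparse gradient of `W` touches only the looked-up rows, so it is tagged `SELECTED_ROWS`, while a dense gradient stays a `LOD_TENSOR`. A toy mimic of that branch, with a placeholder enum rather than the real generated types:

```cpp
// Toy version of the is_sparse dispatch in LookupTableOpGradVarTypeInference.
// The enum is a placeholder for framework::proto::VarDesc's type tags.
#include <iostream>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };

VarType InferGradVarType(bool is_sparse) {
  // Sparse gradients store only the touched rows; dense ones keep the
  // full tensor layout.
  return is_sparse ? VarType::SELECTED_ROWS : VarType::LOD_TENSOR;
}

int main() {
  std::cout << (InferGradVarType(true) == VarType::SELECTED_ROWS) << "\n";
  std::cout << (InferGradVarType(false) == VarType::LOD_TENSOR) << "\n";
  return 0;
}
```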
@@ -140,7 +140,7 @@ class LRNOp : public framework::OperatorWithKernel {
 template <typename T>
 class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LRNOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor) The input of LRN operator. "
...
@@ -102,7 +102,7 @@ class LSTMOp : public framework::OperatorWithKernel {
 class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  LSTMOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input",
              "(LoDTensor) the first input is a LodTensor, which support "
...
@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
 class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  LstmUnitOpMaker(framework::OpProto* proto,
-                  framework::OpAttrChecker* op_checker)
+  LstmUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "Lstm unit only applies non-linear activations, please make sure"
...
@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
 template <typename T>
 class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MarginRankLossOpMaker(framework::OpProto *proto,
-                        framework::OpAttrChecker *op_checker)
+  MarginRankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X1",
              "(2-D tensor with shape [batch_size x 1]) The score for "
...
@@ -130,7 +130,7 @@ class MatMulOp : public framework::OperatorWithKernel {
 class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MatMulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The first input of MatMul op");
     AddInput("Y", "The second input of MatMul op");
...
@@ -40,8 +40,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
 class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MaxSeqenceLenOpProtoMaker(framework::OpProto *proto,
-                            framework::OpAttrChecker *op_checker)
+  MaxSeqenceLenOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("RankTable", "The lod_rank_table.");
     AddOutput("Out", "The max sequence length.");
...
@@ -20,7 +20,7 @@ using framework::Tensor;
 class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  MaxOutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
...
@@ -32,7 +32,7 @@ class MeanOp : public framework::OperatorWithKernel {
 class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  MeanOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of mean op");
     AddOutput("Out", "The output of mean op");
...
@@ -114,8 +114,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
 class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MergeLoDTensorOpProtoMaker(framework::OpProto *proto,
-                             framework::OpAttrChecker *op_checker)
+  MergeLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input LoDTensor, contains complete lod information to "
...
@@ -46,7 +46,7 @@ class MinusOp : public framework::OperatorWithKernel {
 class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The left tensor of minus operator.");
     AddInput("Y", "The right tensor of minus operator.");
...
@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
 class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ModifiedHuberLossOpMaker(framework::OpProto* proto,
-                           framework::OpAttrChecker* op_checker)
+  ModifiedHuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input tensor of modified huber loss op. "
...
@@ -54,8 +54,7 @@ class MomentumOp : public framework::OperatorWithKernel {
 class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MomentumOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  MomentumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param",
              "(Tensor, default Tensor<float>) "
...
...@@ -71,7 +71,7 @@ class MulOpShapeInference : public framework::InferShapeBase { ...@@ -71,7 +71,7 @@ class MulOpShapeInference : public framework::InferShapeBase {
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) MulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op"); AddInput("Y", "The second input of mul op");
......
...@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MultiplexOpMaker(framework::OpProto* proto, MultiplexOpMaker(OpProto* proto, OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", "The index tensor of multiplex operator."); AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.") AddInput("X", "The candidate tensors of multiplex operator.")
......
...@@ -35,8 +35,8 @@ Here we give some examples to show how these rules will be used. ...@@ -35,8 +35,8 @@ Here we give some examples to show how these rules will be used.
```c++ ```c++
class AccumulateOpMaker : public framework::OpProtoAndCheckerMaker { class AccumulateOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AccumulateOpMaker(framework::OpProto *proto, AccumulateOpMaker(OpProto *proto,
framework::OpAttrChecker *op_checker) OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor that has to be accumulated to the output tensor. AddInput("X", "(Tensor) The input tensor that has to be accumulated to the output tensor.
If the output size is not the same as input size, If the output size is not the same as input size,
......
...@@ -43,8 +43,7 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -43,8 +43,7 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLInitOpMaker(framework::OpProto *proto, NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Communicator", AddOutput("Communicator",
"Create Communicator for communicating between gpus"); "Create Communicator for communicating between gpus");
...@@ -52,7 +51,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -52,7 +51,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::DataType::FP32); .SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC( AddComment(R"DOC(
NCCLInit Operator. NCCLInit Operator.
...@@ -141,8 +140,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel { ...@@ -141,8 +140,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
// AllreduceOp // AllreduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLAllReduceOpMaker(framework::OpProto *proto, NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op"); AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
...@@ -163,8 +161,7 @@ AllReduce the input tensors. ...@@ -163,8 +161,7 @@ AllReduce the input tensors.
// ReduceOp // ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLReduceOpMaker(framework::OpProto *proto, NCCLReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op"); AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
...@@ -190,8 +187,7 @@ Reduce the tensors. ...@@ -190,8 +187,7 @@ Reduce the tensors.
// BcastOp // BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLBcastOpMaker(framework::OpProto *proto, NCCLBcastOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op"); AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
......
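Alongside the constructor change, the NCCL hunks retarget `framework::DataType::FP32` to `framework::proto::DataType::FP32`, the relocated protobuf-generated enum. The `dtype` attribute itself stays a plain `int`: the enumerator is stored on the write side and cast back on the read side. A small sketch of that round-trip; only `FP32 = 5` is confirmed by the "(int, default 5 (FP32))" attribute comments, the other values are assumptions:

```cpp
#include <cassert>

namespace framework {
namespace proto {
// Stand-in for the protobuf-generated enum; only FP32 = 5 is confirmed by
// the attribute comments in the hunks above, the rest is assumed.
enum DataType { BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3,
                FP16 = 4, FP32 = 5, FP64 = 6 };
}  // namespace proto
}  // namespace framework

int main() {
  // Write side: SetDefault(framework::proto::DataType::FP32) stores the
  // enumerator into an int-typed attribute.
  int dtype_attr = framework::proto::DataType::FP32;
  assert(dtype_attr == 5);

  // Read side: the op casts the int attribute back to the enum, exactly as
  // UniformRandomOp::GetKernelType does further down in this diff.
  auto dtype = static_cast<framework::proto::DataType>(dtype_attr);
  assert(dtype == framework::proto::FP32);
  return 0;
}
```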
@@ -73,7 +73,7 @@ class NCEOp : public framework::OperatorWithKernel {
 class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  NCEOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
     AddInput(
...
@@ -48,7 +48,7 @@ class PadOp : public framework::OperatorWithKernel {
 class PadOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  PadOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  PadOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input of pad op. "
...
@@ -67,8 +67,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
   ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
 }
-Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
-                             framework::OpAttrChecker *op_checker)
+Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput(
       "X",
@@ -136,8 +135,7 @@ Example:
 )DOC");
 }
-Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
-                             framework::OpAttrChecker *op_checker)
+Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
     : OpProtoAndCheckerMaker(proto, op_checker) {
   AddInput("X",
            "(Tensor) The input tensor of pooling operator. "
...
@@ -40,14 +40,12 @@ class PoolOpGrad : public framework::OperatorWithKernel {
 class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Pool2dOpMaker(framework::OpProto* proto,
-                framework::OpAttrChecker* op_checker);
+  Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Pool3dOpMaker(framework::OpProto* proto,
-                framework::OpAttrChecker* op_checker);
+  Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
 };
 template <typename DeviceContext, typename T>
...
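The two pool hunks above are a matched pair: `pool_op.cc` defines the makers out of line while `pool_op.h` only declares them. The shortened parameter types work in both files because, after the `Pool2dOpMaker::` qualifier, the parameter list is looked up in class scope, where the base-class aliases are visible. A compilable sketch of the split (the framework types are stand-ins, as in the earlier sketch):

```cpp
namespace framework {
namespace proto { struct OpProto {}; }  // stand-ins for the real types
struct OpAttrChecker {};
class OpProtoAndCheckerMaker {
 public:
  using OpProto = proto::OpProto;
  using OpAttrChecker = framework::OpAttrChecker;
  OpProtoAndCheckerMaker(OpProto*, OpAttrChecker*) {}
};
}  // namespace framework

// "pool_op.h": declaration only, as in the @@ -40,14 +40,12 hunk.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};

// "pool_op.cc": out-of-line definition, as in the @@ -67,8 +67,7 hunk;
// names after Pool2dOpMaker:: resolve in class scope, so no framework::.
Pool2dOpMaker::Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {}

int main() { Pool2dOpMaker maker(nullptr, nullptr); }
```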
@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
 class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
-                            framework::OpAttrChecker *op_checker)
+  MaxPool2dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
@@ -178,8 +177,7 @@ Example:
 class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MaxPool3dWithIndexOpMaker(framework::OpProto *proto,
-                            framework::OpAttrChecker *op_checker)
+  MaxPool3dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor) The input tensor of pooling operator. "
...
@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
 class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  PositiveNegativePairOpMaker(framework::OpProto *proto,
-                              framework::OpAttrChecker *op_checker)
+  PositiveNegativePairOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Score",
              "(Tensor, float) Model Score on an item (with "
...
@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
 class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  PrecisionRecallOpMaker(framework::OpProto *proto,
-                         framework::OpAttrChecker *op_checker)
+  PrecisionRecallOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("MaxProbs",
              "(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
...
@@ -38,7 +38,7 @@ class PReluOp : public framework::OperatorWithKernel {
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  PReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of prelu operator.");
     AddInput("Alpha", "The alpha weight of prelu operator.");
...
@@ -59,8 +59,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
 class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ProximalAdagradOpMaker(framework::OpProto *proto,
-                         framework::OpAttrChecker *op_checker)
+  ProximalAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param",
              "(Tensor, default Tensor<float>) "
...
@@ -47,8 +47,7 @@ class ProximalGDOp : public framework::OperatorWithKernel {
 class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ProximalGDOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
+  ProximalGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param",
              "(Tensor, default Tensor<float>) "
...
@@ -45,8 +45,7 @@ class RankLossOp : public framework::OperatorWithKernel {
 class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RankLossOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  RankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Label",
              "(2-D Tensor with shape [batch_size x 1]) "
...
@@ -497,8 +497,7 @@ class RecurrentGradOp : public RecurrentBase {
 class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RecurrentOpProtoMaker(framework::OpProto *proto,
-                        framework::OpAttrChecker *op_checker)
+  RecurrentOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(kInputs, "rnn inputs").AsDuplicable();
     AddInput(kInitialStates, "rnn initial states").AsDuplicable();
...
@@ -97,7 +97,7 @@ class RecvOp : public framework::OperatorBase {
 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("RX", "(Tensor) Input tensor to be saved");
     AddComment(R"DOC(
...
@@ -83,7 +83,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
 class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor) The input tensor. Tensors with rank at most 6 are "
@@ -135,8 +135,7 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
 class ReduceSumOpMaker : public ReduceOpMaker {
  public:
-  ReduceSumOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : ReduceOpMaker(proto, op_checker) {
     SetComment("ReduceSum", "sum");
     AddComment(comment_);
@@ -145,8 +144,7 @@ class ReduceSumOpMaker : public ReduceOpMaker {
 class ReduceMeanOpMaker : public ReduceOpMaker {
  public:
-  ReduceMeanOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
+  ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : ReduceOpMaker(proto, op_checker) {
     SetComment("ReduceMean", "mean");
     AddComment(comment_);
@@ -155,8 +153,7 @@ class ReduceMeanOpMaker : public ReduceOpMaker {
 class ReduceMaxOpMaker : public ReduceOpMaker {
  public:
-  ReduceMaxOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : ReduceOpMaker(proto, op_checker) {
     SetComment("ReduceMax", "max");
     AddComment(comment_);
@@ -165,8 +162,7 @@ class ReduceMaxOpMaker : public ReduceOpMaker {
 class ReduceMinOpMaker : public ReduceOpMaker {
  public:
-  ReduceMinOpMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : ReduceOpMaker(proto, op_checker) {
     SetComment("ReduceMin", "min");
     AddComment(comment_);
...
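The reduce hunks show a second pattern besides the signature change: `ReduceSumOpMaker` and its siblings inherit from `ReduceOpMaker` and only call `SetComment(...)` followed by `AddComment(comment_)`. A hedged sketch of how that can work, assuming `comment_` is a protected string the base fills in from the op name and reduction word (the generated text here is invented; only the call sequence appears in the diff):

```cpp
// Sketch of the subclassed-maker pattern: the base owns comment_,
// SetComment fills it, and each subclass emits it once.
#include <iostream>
#include <string>

class ReduceOpMakerSketch {
 protected:
  void SetComment(const std::string& name, const std::string& op) {
    comment_ = name + " Operator.\n\nThis operator computes the " + op +
               " of the input tensor along the given dimension.";
  }
  std::string comment_;
};

class ReduceSumOpMakerSketch : public ReduceOpMakerSketch {
 public:
  ReduceSumOpMakerSketch() {
    SetComment("ReduceSum", "sum");
    std::cout << comment_ << "\n";  // stands in for AddComment(comment_)
  }
};

int main() { ReduceSumOpMakerSketch maker; }
```

This keeps the four reduce makers down to two lines each while the shared documentation lives in one place.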
@@ -77,8 +77,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
 class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ReshapeOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of reshape operator.");
     AddOutput("Out", "The output tensor of reshape operator.");
...
@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel {
 class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RmspropOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  RmspropOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param",
              "(Tensor, default Tensor<float>) "
...
@@ -57,15 +57,14 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
 class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RNNMemoryHelperOpInfoMaker(framework::OpProto *proto,
-                             framework::OpAttrChecker *op_checker)
+  RNNMemoryHelperOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "");
     AddOutput("Out", "");
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddComment("");
   }
 };
@@ -114,8 +113,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
 class RNNMemoryHelperGradOpInfoMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  RNNMemoryHelperGradOpInfoMaker(framework::OpProto *proto,
-                                 framework::OpAttrChecker *op_checker)
+  RNNMemoryHelperGradOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(framework::GradVarName("Out"), "");
     AddInput("X", "");
@@ -124,7 +122,7 @@ class RNNMemoryHelperGradOpInfoMaker
     AddAttr<int>("dtype",
                  "(int, default 5 (FP32)) "
                  "Output data type")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
     AddComment("");
   }
 };
...
@@ -99,8 +99,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
 class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ROIPoolOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  ROIPoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor), "
...
@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel {
 class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  RowConvOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  RowConvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LoDTensor), the input(X) is a LodTensor, which supports "
...
@@ -94,8 +94,7 @@ class SaveOp : public framework::OperatorBase {
 class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SaveOpProtoMaker(framework::OpProto *proto,
-                   framework::OpAttrChecker *op_checker)
+  SaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor ) Input tensor to be saved");
     AddComment(R"DOC(
...
@@ -38,7 +38,7 @@ class ScaleOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input tensor of scale operator.");
     AddOutput("Out", "(Tensor) Output tensor of scale operator.");
...
@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
 class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ScatterOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Ref", "The source input of scatter op");
     AddInput("Index",
...
@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
 class SendOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SendOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input tensor to be saved");
     AddOutput("Out", "(Tensor) Output fetched from server");
...
@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
 class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceConcatOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  SequenceConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LodTensorArray) Input is a vector of LoDTensor, "
...
@@ -100,8 +100,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
 class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceConvOpMaker(framework::OpProto* proto,
-                      framework::OpAttrChecker* op_checker)
+  SequenceConvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
...
@@ -37,8 +37,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
 class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceExpandOpMaker(framework::OpProto* proto,
-                        framework::OpAttrChecker* op_checker)
+  SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor or LoDTensor) The input(X) of this operator can be a "
...
@@ -37,8 +37,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
 class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequencePoolOpMaker(framework::OpProto* proto,
-                      framework::OpAttrChecker* op_checker)
+  SequencePoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
     AddOutput("Out",
...
@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
 class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceSliceOpMaker(framework::OpProto* proto,
-                       framework::OpAttrChecker* op_checker)
+  SequenceSliceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LoDTensor), "
...
@@ -33,8 +33,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
 class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceSoftmaxOpMaker(framework::OpProto* proto,
-                         framework::OpAttrChecker* op_checker)
+  SequenceSoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
...
@@ -43,7 +43,7 @@ class SGDOp : public framework::OperatorWithKernel {
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("LearningRate", "(Tensor) Learning rate of SGD");
...
@@ -54,8 +54,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
 class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
-                              framework::OpAttrChecker *op_checker)
+  ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
     AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
...
@@ -86,8 +86,8 @@ class SigmoidCrossEntropyWithLogitsGradOp
 class SigmoidCrossEntropyWithLogitsOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  SigmoidCrossEntropyWithLogitsOpMaker(framework::OpProto* proto,
-                                       framework::OpAttrChecker* op_checker)
+  SigmoidCrossEntropyWithLogitsOpMaker(OpProto* proto,
+                                       OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
...
@@ -34,7 +34,7 @@ class SignOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class SignOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SignOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input tensor of sign operator.");
     AddOutput("Out", "(Tensor) Output tensor of sign operator.");
...
@@ -47,8 +47,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
 template <typename AttrType>
 class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SmoothL1LossOpMaker(framework::OpProto* proto,
-                      framework::OpAttrChecker* op_checker)
+  SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(Tensor, default Tensor<float>) A tensor with rank at least 2. "
...
@@ -36,8 +36,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftmaxOpMaker(framework::OpProto* proto,
-                 framework::OpAttrChecker* op_checker)
+  SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "The input tensor of softmax. "
...
@@ -20,8 +20,7 @@ namespace operators {
 class SoftmaxWithCrossEntropyOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
-                                 framework::OpAttrChecker* op_checker)
+  SoftmaxWithCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Logits",
              "(Tensor, default: Tensor<float>), The unscaled log probabilities "
...
@@ -118,8 +118,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
 class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SplitLoDTensorOpProtoMaker(framework::OpProto *proto,
-                             framework::OpAttrChecker *op_checker)
+  SplitLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input LoDTensor");
     AddInput("Mask", "A bool column vector which mask the input");
...
@@ -65,7 +65,7 @@ class SplitOp : public framework::OperatorWithKernel {
 class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SplitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  SplitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input tensor of the split operator.");
     AddOutput("Out", "(Tensor) Output tensors of the split operator.")
...
@@ -18,7 +18,7 @@ namespace operators {
 class SppOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  SppOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
...
@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
 class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SquaredL2DistanceOpMaker(framework::OpProto* proto,
-                           framework::OpAttrChecker* op_checker)
+  SquaredL2DistanceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input of SquaredL2DistanceOp.");
     AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp.");
...
@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel {
 class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SquaredL2NormOpMaker(framework::OpProto* proto,
-                       framework::OpAttrChecker* op_checker)
+  SquaredL2NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) The input of squared_l2_norm op.");
     AddOutput("Out", "(Scalar) The output of squared_l2_norm op.");
...
@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
                    "Output(Out) of SumOp should not be null.");
     if (ctx->IsRuntime() &&
         ctx->GetOutputsVarType("Out")[0] ==
-            framework::VarDesc::LOD_TENSOR_ARRAY) {
+            framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
       return;  // skip runtime infershape when is tensor array;
     }
@@ -72,8 +72,8 @@ class SumOp : public framework::OperatorWithKernel {
       PADDLE_ENFORCE_NE(dtype, -1,
                         "Sum operator should have at least one tensor");
-      return framework::OpKernelType(static_cast<framework::DataType>(dtype),
-                                     ctx.device_context());
+      return framework::OpKernelType(
+          static_cast<framework::proto::DataType>(dtype), ctx.device_context());
     } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
       return framework::OpKernelType(
           framework::ToDataType(
@@ -98,7 +98,7 @@ class SumOp : public framework::OperatorWithKernel {
 class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  SumOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
         .AsDuplicable();
@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDescBind& op_desc,
                   framework::BlockDescBind* block) const override {
     auto& inputs = op_desc.Input("X");
-    auto var_type = framework::VarDesc::SELECTED_ROWS;
+    auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
     for (auto& name : op_desc.Input("X")) {
       VLOG(10) << name << " "
@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string& name) {
           return block->FindRecursiveOrCreateVar(name)->GetType() ==
-                 framework::VarDesc::LOD_TENSOR;
+                 framework::proto::VarDesc::LOD_TENSOR;
         });
     auto is_tensor_array = [block](const std::string& name) {
       return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() ==
-             framework::VarDesc::LOD_TENSOR_ARRAY;
+             framework::proto::VarDesc::LOD_TENSOR_ARRAY;
     };
     bool any_input_is_tensor_array =
@@ -152,9 +152,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       PADDLE_ENFORCE(all_inputs_are_tensor_array,
                      "Not all inputs are tensor array:\n%s", os.str());
       }
-      var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
+      var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY;
     } else if (any_input_is_lod_tensor) {
-      var_type = framework::VarDesc::LOD_TENSOR;
+      var_type = framework::proto::VarDesc::LOD_TENSOR;
     }
     auto out_var_name = op_desc.Output("Out").front();
...
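Beyond the renames, the `SumOpVarTypeInference` hunk is worth reading for the rule it encodes: the output defaults to `SELECTED_ROWS`, is promoted to `LOD_TENSOR` if any input is a LoDTensor, and becomes `LOD_TENSOR_ARRAY` when the inputs are tensor arrays, in which case *all* inputs must be arrays. A hedged sketch of that rule in isolation (enum and function names are stand-ins, not the Paddle API):

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

enum class VarType { SELECTED_ROWS, LOD_TENSOR, LOD_TENSOR_ARRAY };

VarType InferSumOutputType(const std::vector<VarType>& inputs) {
  bool any_lod_tensor =
      std::any_of(inputs.begin(), inputs.end(),
                  [](VarType t) { return t == VarType::LOD_TENSOR; });
  bool any_tensor_array =
      std::any_of(inputs.begin(), inputs.end(),
                  [](VarType t) { return t == VarType::LOD_TENSOR_ARRAY; });
  if (any_tensor_array) {
    // Mirrors the PADDLE_ENFORCE in the hunk: mixing arrays with
    // non-arrays is an error.
    bool all_arrays =
        std::all_of(inputs.begin(), inputs.end(),
                    [](VarType t) { return t == VarType::LOD_TENSOR_ARRAY; });
    if (!all_arrays) throw std::runtime_error("Not all inputs are tensor array");
    return VarType::LOD_TENSOR_ARRAY;
  }
  if (any_lod_tensor) return VarType::LOD_TENSOR;
  return VarType::SELECTED_ROWS;  // the default in the hunk above
}

int main() {
  assert(InferSumOutputType({VarType::SELECTED_ROWS, VarType::LOD_TENSOR}) ==
         VarType::LOD_TENSOR);
  assert(InferSumOutputType({VarType::SELECTED_ROWS}) ==
         VarType::SELECTED_ROWS);
}
```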
@@ -51,8 +51,7 @@ class WriteToArrayOp : public ArrayOp {
 class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  WriteToArrayOpProtoMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  WriteToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
     AddInput(
@@ -104,7 +103,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
     auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
                             "Cannot found %s", out_name);
-    out.SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
+    out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
     auto *x = block->FindVarRecursive(x_name);
     if (x != nullptr) {
       out.SetDataType(x->GetDataType());
@@ -140,8 +139,7 @@ class ReadFromArrayOp : public ArrayOp {
 class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  ReadFromArrayProtoMaker(framework::OpProto *proto,
-                          framework::OpAttrChecker *op_checker)
+  ReadFromArrayProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(TensorArray) the array will be read from.");
     AddInput("I",
...
@@ -46,7 +46,7 @@ class TopkOp : public framework::OperatorWithKernel {
 class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  TopkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) The input of Topk op");
     AddOutput("Out", "(Tensor) The output tensor of Topk op");
...
@@ -55,8 +55,7 @@ class TransposeOp : public framework::OperatorWithKernel {
 class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  TransposeOpMaker(framework::OpProto* proto,
-                   framework::OpAttrChecker* op_checker)
+  TransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
...
@@ -66,15 +66,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
+        static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
 };
 class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  UniformRandomOpMaker(framework::OpProto* proto,
-                       framework::OpAttrChecker* op_checker)
+  UniformRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "(Tensor) The output tensor of uniform random op");
     AddComment(R"DOC(
@@ -100,7 +99,7 @@ uniform distribution.
                  "0 means use a seed generated by the system.")
         .SetDefault(0);
     AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
-        .SetDefault(framework::DataType::FP32);
+        .SetDefault(framework::proto::DataType::FP32);
   }
 };
 }  // namespace operators
...
@@ -18,8 +18,7 @@ namespace operators {
 class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  Unpool2dOpMaker(framework::OpProto* proto,
-                  framework::OpAttrChecker* op_checker)
+  Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
         "X",
...
@@ -64,7 +64,7 @@ class WhileOp : public framework::OperatorBase {
 class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  WhileOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(kParameters,
              "A set of variables, which are required by operators inside the "
@@ -321,10 +321,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
         continue;
       }
       auto dims = ctx->GetInputsElementDim(kParameters, i);
-      if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
+      if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
-      } else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
+      } else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
         // not sure how to set the dim of LOD_TENSOR_ARRAY
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
...
@@ -31,31 +31,32 @@ std::string Escape(const std::string& s) {
   return r;
 }
-std::string AttrType(paddle::framework::AttrType at) {
+std::string AttrType(paddle::framework::proto::AttrType at) {
   switch (at) {
-    case paddle::framework::INT:
+    case paddle::framework::proto::INT:
       return "int";
-    case paddle::framework::FLOAT:
+    case paddle::framework::proto::FLOAT:
      return "float";
-    case paddle::framework::STRING:
+    case paddle::framework::proto::STRING:
      return "string";
-    case paddle::framework::BOOLEAN:
+    case paddle::framework::proto::BOOLEAN:
      return "bool";
-    case paddle::framework::INTS:
+    case paddle::framework::proto::INTS:
      return "int array";
-    case paddle::framework::FLOATS:
+    case paddle::framework::proto::FLOATS:
      return "float array";
-    case paddle::framework::STRINGS:
+    case paddle::framework::proto::STRINGS:
      return "string array";
-    case paddle::framework::BOOLEANS:
+    case paddle::framework::proto::BOOLEANS:
      return "bool array";
-    case paddle::framework::BLOCK:
+    case paddle::framework::proto::BLOCK:
      return "block id";
   }
   return "UNKNOWN";  // not possible
 }
-void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) {
+void PrintVar(const paddle::framework::proto::OpProto::Var& v,
+              std::stringstream& ss) {
   ss << "  { "
      << "\n"
      << "  \"name\" : \"" << Escape(v.name()) << "\",\n"
@@ -65,7 +66,7 @@ void PrintVar(const paddle::framework::proto::OpProto::Var& v,
      << "  },";
 }
-void PrintAttr(const paddle::framework::OpProto::Attr& a,
+void PrintAttr(const paddle::framework::proto::OpProto::Attr& a,
                std::stringstream& ss) {
   ss << "  { "
      << "\n"
@@ -81,7 +82,7 @@ void PrintOpProto(const std::string& type,
                   std::stringstream& ss) {
   std::cerr << "Processing " << type << "\n";
-  const paddle::framework::OpProto* p = opinfo.proto_;
+  const paddle::framework::proto::OpProto* p = opinfo.proto_;
   if (p == nullptr) {
     return;  // It is possible that an operator doesn't have OpProto.
   }
...
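The `AttrType()` switch above illustrates where the extra `proto::` segment comes from: protoc emits file-scope enums into the C++ namespace matching the proto package, and the messages evidently moved from package `paddle.framework` to `paddle.framework.proto`, so every generated enumerator moved with them. A compact sketch of that mapping (the `.proto` lines in the comment and the enum values are illustrative, not the real file):

```cpp
#include <string>

// framework.proto (assumed):         generated C++ (sketch):
//   package paddle.framework.proto;    namespace paddle::framework::proto
//   enum AttrType { INT = 0; ... }     enum AttrType { INT = 0, ... };
namespace paddle { namespace framework { namespace proto {
enum AttrType { INT = 0, FLOAT = 1, STRING = 2 };  // stand-in for generated code
}}}  // namespace paddle::framework::proto

std::string AttrTypeName(paddle::framework::proto::AttrType at) {
  switch (at) {
    case paddle::framework::proto::INT:
      return "int";
    case paddle::framework::proto::FLOAT:
      return "float";
    case paddle::framework::proto::STRING:
      return "string";
  }
  return "UNKNOWN";
}

int main() { return AttrTypeName(paddle::framework::proto::INT) == "int" ? 0 : 1; }
```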
@@ -144,7 +144,7 @@ void BindProgramDesc(py::module &m) {
       .def("serialize_to_string", SerializeMessage<ProgramDescBind>)
       .def("parse_from_string",
            [](ProgramDescBind &program_desc, const std::string &data) {
-             ProgramDesc *desc = program_desc.Proto();
+             proto::ProgramDesc *desc = program_desc.Proto();
              PADDLE_ENFORCE(desc->ParseFromString(data),
                             "Fail to parse ProgramDesc from string. This could "
                             "be a bug of Paddle.");
@@ -184,14 +184,14 @@ void BindBlockDesc(py::module &m) {
 }
 void BindVarDsec(py::module &m) {
-  py::enum_<DataType>(m, "DataType", "")
-      .value("BOOL", DataType::BOOL)
-      .value("INT16", DataType::INT16)
-      .value("INT32", DataType::INT32)
-      .value("INT64", DataType::INT64)
-      .value("FP16", DataType::FP16)
-      .value("FP32", DataType::FP32)
-      .value("FP64", DataType::FP64);
+  py::enum_<proto::DataType>(m, "DataType", "")
+      .value("BOOL", proto::DataType::BOOL)
+      .value("INT16", proto::DataType::INT16)
+      .value("INT32", proto::DataType::INT32)
+      .value("INT64", proto::DataType::INT64)
+      .value("FP16", proto::DataType::FP16)
+      .value("FP32", proto::DataType::FP32)
+      .value("FP64", proto::DataType::FP64);
   py::class_<VarDescBind> var_desc(m, "VarDesc", "");
   var_desc
@@ -213,27 +213,27 @@ void BindVarDsec(py::module &m) {
       .def("persistable", &VarDescBind::Persistable)
       .def("set_persistable", &VarDescBind::SetPersistable);
-  py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
-      .value("LOD_TENSOR", VarDesc::LOD_TENSOR)
-      .value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
-      .value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
-      .value("FETCH_LIST", VarDesc::FETCH_LIST)
-      .value("STEP_SCOPES", VarDesc::STEP_SCOPES)
-      .value("LOD_RANK_TABLE", VarDesc::LOD_RANK_TABLE)
-      .value("LOD_TENSOR_ARRAY", VarDesc::LOD_TENSOR_ARRAY);
+  py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
+      .value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
+      .value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS)
+      .value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH)
+      .value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
+      .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
+      .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
+      .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY);
 }
 void BindOpDesc(py::module &m) {
-  py::enum_<AttrType>(m, "AttrType", "")
-      .value("INT", AttrType::INT)
-      .value("INTS", AttrType::INTS)
-      .value("FLOAT", AttrType::FLOAT)
-      .value("FLOATS", AttrType::FLOATS)
-      .value("STRING", AttrType::STRING)
-      .value("STRINGS", AttrType::STRINGS)
-      .value("BOOL", AttrType::BOOLEAN)
-      .value("BOOLS", AttrType::BOOLEANS)
-      .value("BLOCK", AttrType::BLOCK);
+  py::enum_<proto::AttrType>(m, "AttrType", "")
+      .value("INT", proto::AttrType::INT)
+      .value("INTS", proto::AttrType::INTS)
+      .value("FLOAT", proto::AttrType::FLOAT)
+      .value("FLOATS", proto::AttrType::FLOATS)
+      .value("STRING", proto::AttrType::STRING)
+      .value("STRINGS", proto::AttrType::STRINGS)
+      .value("BOOL", proto::AttrType::BOOLEAN)
+      .value("BOOLS", proto::AttrType::BOOLEANS)
+      .value("BLOCK", proto::AttrType::BLOCK);
   py::class_<OpDescBind> op_desc(m, "OpDesc", "");
   op_desc.def("type", &OpDescBind::Type)
...
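These enum bindings are what make the protobuf types visible from Python. As a minimal sketch of how they surface, assuming the compiled pybind module is importable as `paddle.v2.fluid.core` (the path used elsewhere in this tree):

```python
# Minimal sketch, assuming paddle.v2.fluid.core is the compiled pybind
# module; the attribute names mirror the .value(...) bindings above.
from paddle.v2.fluid import core

dtype = core.DataType.FP32                  # proto::DataType::FP32
var_type = core.VarDesc.VarType.LOD_TENSOR  # enum nested under VarDesc
attr_type = core.AttrType.BOOL              # maps to proto::AttrType::BOOLEAN
```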
@@ -288,12 +288,12 @@ All parameter, weight, gradient are variables in Paddle.
         for (const auto &t : targets) {
           prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
         }
-        ProgramDesc pruned_desc;
+        proto::ProgramDesc pruned_desc;
         Prune(*prog_with_targets.Proto(), &pruned_desc);
         return new ProgramDescBind(pruned_desc);
       });
   m.def("inference_optimize", [](ProgramDescBind &origin) {
-    ProgramDesc pruned_desc;
+    proto::ProgramDesc pruned_desc;
     InferenceOptimize(*(origin.Proto()), &pruned_desc);
     return new ProgramDescBind(pruned_desc);
   });
@@ -345,7 +345,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<OperatorBase>(m, "Operator")
       .def_static("create",
                   [](py::bytes protobin) {
-                    OpDesc desc;
+                    proto::OpDesc desc;
                     PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                    "Cannot parse user input to OpDesc");
                     PADDLE_ENFORCE(desc.IsInitialized(),
@@ -398,7 +398,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
       .def_static("create",
                   [](py::bytes protobin) -> operators::CondOp * {
-                    OpDesc desc;
+                    proto::OpDesc desc;
                     PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                    "Cannot parse user input to OpDesc");
                     PADDLE_ENFORCE(desc.IsInitialized(),
...
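Both `Operator.create` and `CondOp.create` accept a serialized `proto::OpDesc`, so the Python side only needs to produce the wire format. A hedged sketch of that round trip, assuming the generated protobuf module is importable as `framework_pb2`:

```python
# Hedged sketch: build a proto::OpDesc and pass its serialized bytes to the
# binding above. The framework_pb2 import path is an assumption, and a real
# fill_constant op would also need its required attributes (shape, value).
from paddle.v2.fluid import core
from paddle.v2.fluid.proto import framework_pb2

desc = framework_pb2.OpDesc()
desc.type = "fill_constant"
out = desc.outputs.add()
out.parameter = "Out"
out.arguments.append("out_var")
op = core.Operator.create(desc.SerializeToString())
```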
@@ -14,9 +14,8 @@ make -j `nproc` print_operators_doc
 paddle/pybind/print_operators_doc > doc/en/html/operators.json
 # check websites for broken links
-# It will be failed now!
-#linkchecker doc/en/html/index.html
-#linkchecker doc/cn/html/index.html
+linkchecker doc/en/html/index.html
+linkchecker doc/cn/html/index.html
 # Parse Github URL
 REPO=`git config remote.origin.url`
...
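With the linkchecker runs re-enabled above, the script still emits `operators.json` from `print_operators_doc`. A hedged sketch of inspecting that file; the top-level layout is an assumption, and only the field style follows from the `PrintVar`/`PrintAttr` changes above:

```python
# Hedged sketch: operators.json is assumed to parse as one JSON document
# whose entries carry the "name"/"comment" style fields emitted by PrintVar.
import json

with open('doc/en/html/operators.json') as f:
    ops = json.load(f)
print('%d operator protos documented' % len(ops))
```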
@@ -25,10 +25,10 @@ from paddle.trainer.config_parser import *
 __all__ = [
     'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
     "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
-    'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
-    'simple_attention', 'dot_product_attention', 'multi_head_attention',
-    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
-    'inputs', 'outputs'
+    'img_conv_group', 'img_separable_conv', 'vgg_16_network', 'gru_unit',
+    'gru_group', 'simple_gru', 'simple_attention', 'dot_product_attention',
+    'multi_head_attention', 'simple_gru2', 'bidirectional_gru',
+    'text_conv_pool', 'bidirectional_lstm', 'inputs', 'outputs'
 ]
 ######################################################
@@ -251,13 +251,13 @@ def img_conv_bn_pool(input,
                      pool_layer_attr=None):
     """
     Convolution, batch normalization, pooling group.
     Img input => Conv => BN => Pooling => Output.
     :param name: group name.
     :type name: basestring
     :param input: input layer.
     :type input: LayerOutput
     :param filter_size: see img_conv_layer for details.
     :type filter_size: int
     :param num_filters: see img_conv_layer for details.
@@ -435,6 +435,85 @@ def img_conv_group(input,
         input=tmp, stride=pool_stride, pool_size=pool_size, pool_type=pool_type)
@wrap_name_default("separable_conv")
def img_separable_conv(input,
                       num_channels,
                       num_out_channels,
                       filter_size,
                       stride=1,
                       padding=0,
                       depth_multiplier=1,
                       act=None,
                       bias_attr=None,
                       param_attr=None,
                       shared_bias=True,
                       layer_type='exconv',
                       name=None):
    """
    Separable Convolution.

    The separable convolution module consists of a depthwise convolution
    that acts separately on input channels, followed by a pointwise
    convolution with 1*1 kernels that mixes channels. It is used in
    Xception: https://arxiv.org/pdf/1610.02357.pdf

    :param input: input layer.
    :type input: LayerOutput
    :param num_channels: the number of input channels.
    :type num_channels: int
    :param num_out_channels: the number of output channels.
    :type num_out_channels: int
    :param filter_size: the filter size for the depthwise convolution.
    :type filter_size: int|tuple
    :param stride: the stride size for the depthwise convolution.
    :type stride: int|tuple
    :param padding: the padding size for the depthwise convolution.
    :type padding: int|tuple
    :param depth_multiplier: the number of filters per input channel in the
                             depthwise convolution.
    :type depth_multiplier: int
    :param act: the activation function for the output.
    :type act: BaseActivation
    :param bias_attr: see img_conv_layer for details.
    :type bias_attr: ParameterAttribute
    :param param_attr: see img_conv_layer for details.
    :type param_attr: ParameterAttribute
    :param shared_bias: see img_conv_layer for details.
    :type shared_bias: bool
    :param layer_type: see img_conv_layer for details.
    :type layer_type: basestring
    :return: layer's output
    :rtype: LayerOutput
    """
    __depthwise_conv__ = img_conv_layer(
        name="%s_depthwise_conv" % name,
        input=input,
        num_channels=num_channels,
        num_filters=num_channels * depth_multiplier,
        groups=num_channels,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=LinearActivation(),
        bias_attr=bias_attr,
        param_attr=param_attr,
        shared_biases=shared_bias,
        layer_type=layer_type)
    __pointwise_conv__ = img_conv_layer(
        name="%s_pointwise_conv" % name,
        input=__depthwise_conv__,
        num_channels=num_channels * depth_multiplier,
        num_filters=num_out_channels,
        filter_size=1,
        stride=1,
        padding=0,
        act=act,
        bias_attr=bias_attr,
        param_attr=param_attr,
        shared_biases=shared_bias)
    return __pointwise_conv__
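A hedged usage sketch for the new helper; the layer names and sizes are illustrative, and the wildcard import is the usual `paddle.trainer_config_helpers` entry point:

```python
# Illustrative only: a 3-channel 32x32 image through a 3x3 depthwise stage
# plus a 1x1 pointwise stage producing 16 output channels.
from paddle.trainer_config_helpers import *

img = data_layer(name='image', size=3 * 32 * 32)
feat = img_separable_conv(
    input=img,
    num_channels=3,
    num_out_channels=16,
    filter_size=3,
    stride=1,
    padding=1,
    act=ReluActivation())
```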
 def small_vgg(input_image, num_channels, num_classes):
     def __vgg__(ipt, num_filter, times, dropouts, num_channels_=None):
         return img_conv_group(
@@ -648,7 +727,7 @@ def lstmemory_unit(input,
                    lstm_bias_attr=None,
                    lstm_layer_attr=None):
     """
     lstmemory_unit defines the calculation process of a LSTM unit during a
     single time step. This function is not a recurrent layer, so it can not be
     directly used to process sequence input. This function is always used in
     recurrent_group (see layers.py for more details) to implement attention
@@ -869,7 +948,7 @@ def gru_unit(input,
              gru_layer_attr=None,
              naive=False):
     """
     gru_unit defines the calculation process of a gated recurrent unit during a single
     time step. This function is not a recurrent layer, so it can not be
     directly used to process sequence input. This function is always used in
     the recurrent_group (see layers.py for more details) to implement attention
@@ -1012,7 +1091,7 @@ def simple_gru(input,
     simple_gru in network.py. The reason why there are so many interfaces is
     that we have two ways to implement recurrent neural network. One way is to
     use one complete layer to implement rnn (including simple rnn, gru and lstm)
     with multiple time steps, such as recurrent_layer, lstmemory, grumemory. But
     the multiplication operation :math:`W x_t` is not computed in these layers.
     See details in their interfaces in layers.py.
     The other implementation is to use a recurrent group which can ensemble a
@@ -1116,7 +1195,7 @@ def simple_gru2(input,
     :type act: BaseActivation
     :param gate_act: gate activation type of gru
     :type gate_act: BaseActivation
     :param gru_bias_attr: bias parameter attribute of gru layer,
                           False means no bias, None means default bias.
     :type gru_bias_attr: ParameterAttribute|False|None
     :param gru_param_attr: param parameter attribute of gru layer,
@@ -1189,7 +1268,7 @@ def bidirectional_gru(input,
     :type size: int
     :param return_seq: If set False, the outputs of the last time step are
                        concatenated and returned.
                        If set True, the entire output sequences in forward
                        and backward directions are concatenated and returned.
     :type return_seq: bool
     :return: LayerOutput object.
@@ -1278,7 +1357,7 @@ def bidirectional_lstm(input,
     :type size: int
     :param return_seq: If set False, the outputs of the last time step are
                        concatenated and returned.
                        If set True, the entire output sequences in forward
                        and backward directions are concatenated and returned.
     :type return_seq: bool
     :return: LayerOutput object.
...
@@ -13,7 +13,7 @@ __all__ = [
     'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
     'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
     'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
-    'lstm_unit'
+    'lstm_unit', 'reduce_sum'
 ]
@@ -402,8 +402,8 @@ def chunk_eval(input,
         },
         attrs={
             "num_chunk_types": num_chunk_types,
-            'chunk_scheme': chunk_scheme,
-            'excluded_chunk_types': excluded_chunk_types or []
+            "chunk_scheme": chunk_scheme,
+            "excluded_chunk_types": excluded_chunk_types or []
         })
     return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
@@ -935,3 +935,47 @@ def lstm_unit(x_t,
         attrs={"forget_bias": forget_bias})
     return h, c
def reduce_sum(input, dim=None, keep_dim=False):
    """
    Computes the sum of tensor elements over the given dimension.

    Args:
        input (Variable): The input variable which is a Tensor or LoDTensor.
        dim (int|None): The dimension along which the sum is performed. If
            :attr:`None`, sum all elements of :attr:`input` and return a
            Tensor variable with a single element, otherwise must be in the
            range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`,
            the dimension to reduce is :math:`rank + dim`.
        keep_dim (bool): Whether to reserve the reduced dimension in the
            output Tensor. The result tensor will have one fewer dimension
            than the :attr:`input` unless :attr:`keep_dim` is true.

    Returns:
        Variable: The reduced Tensor variable.

    Examples:
        .. code-block:: python

          # x is a Tensor variable with following elements:
          #    [[0.2, 0.3, 0.5, 0.9]
          #     [0.1, 0.2, 0.6, 0.7]]
          # Each example is followed by the corresponding output tensor.
          fluid.layers.reduce_sum(x)  # [3.5]
          fluid.layers.reduce_sum(x, dim=0)  # [0.3, 0.5, 1.1, 1.6]
          fluid.layers.reduce_sum(x, dim=-1)  # [1.9, 1.6]
          fluid.layers.reduce_sum(x, dim=1, keep_dim=True)  # [[1.9], [1.6]]
    """
    helper = LayerHelper('reduce_sum', **locals())
    out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(
        type='reduce_sum',
        inputs={'X': input},
        outputs={'Out': out},
        attrs={
            'dim': dim if dim is not None else 0,
            'keep_dim': keep_dim,
            'reduce_all': dim is None
        })
    return out
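A hedged sketch of the new layer wired into a program; the `fluid.layers.data` plumbing is assumed from the rest of the v2 fluid API:

```python
# Assumed wiring: the batch dimension is implicit in the data layer shape,
# so reduce_sum(x) collapses every element and dim=-1 sums per example.
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
total = fluid.layers.reduce_sum(x)            # single-element Tensor
per_row = fluid.layers.reduce_sum(x, dim=-1)  # one sum per example
```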