提交 86ed4ef6 编写于 作者: L Luo Tao

Merge branch 'develop' into mklml

......@@ -9,7 +9,7 @@ import subprocess
import platform
COPYRIGHT = '''
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
# Embed Paddle Inference in Your Application
Paddle inference offers the APIs in `C` and `C++` languages.
One can easily deploy a model trained by Paddle following the steps as below:
1. Optimize the native model;
2. Write some codes for deployment.
Let's explain the steps in detail.
## Optimize the native Fluid Model
The native model that get from the training phase needs to be optimized for that.
- Clean the noise such as the cost operators that do not need inference;
- Prune unnecessary computation fork that has nothing to do with the output;
- Remove extraneous variables;
- Memory reuse for native Fluid executor;
- Translate the model storage format to some third-party engine's, so that the inference API can utilize the engine for acceleration;
We have an official tool to do the optimization, call `paddle_inference_optimize --help` for more information.
## Write some codes
Read `paddle_inference_api.h` for more information.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
namespace paddle {
class Predictor {
public:
struct Attr;
Predictor() = default;
// Build the network before inference.
bool Init(const Attr& attr);
// Predict an record.
// Arguments:
// inputs: the name of the input variables.
// outputs: the name of the output varaibles.
// input_shapes: the shape of the input variables.
// output_shapes: the shape of the output variables.
// input_data: the data of the input variables.
// output_data: the data of the output variables.
bool Run(const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<std::vector<float>>& input_data,
std::vector<std::vector<float>>* output_data);
// Clone a predictor that share the model weights.
Predictor* Clone();
// Destroy the Predictor.
~Predictor();
struct Attr {
enum class EngineKind;
std::string model_dir; // path to the model directory.
bool enable_engine{false}; // Enable to execute (part of) the model on
// third-party engines.
EngineKind engine_kind{Attr::EngineKind::kNone};
enum class EngineKind {
kNone = -1, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
kTensorRT, // Use TensorRT for inference.
kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
};
};
} // namespace paddle
......@@ -77,8 +77,7 @@ print "The sematic-vector of testA: ", paddle.infer(fA, parameters, testA)
### Example 2. Sharing Parameters between "Models"
We use [GAN](https://github.com/PaddlePaddle/book/tree/develop/gan) in
this example. In the following example program, `d0` and `d1`
We use GAN in this example. In the following example program, `d0` and `d1`
correspond to the two networks in the following figure:
<img src="https://github.com/wangyang59/book/raw/00036f4b0da5225041a6824587c1a01cf20159b1/gan/image/gan_ig.png" width=400 />
......
......@@ -75,7 +75,7 @@ Different layout leads to different implementation of the operator kernel. There
- The inference of Layout is at run-time, not at compile-time.
- Every operator has to implement different kernels for different layouts. Let's take MKLDNN as an example. If we want to implement an MKLDNN convolution operator, we have to implement all the kernels for different layouts, which are listed [here](http://01org.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to register kernels for MKLDNN operators.
- Every operator has to implement different kernels for different layouts. Let's take MKLDNN as an example. If we want to implement an MKLDNN convolution operator, we have to implement all the kernels for different layouts, which are listed [here](http://intel.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to register kernels for MKLDNN operators.
`Layout` is also defined as a enum variable:
......
# Distributed Training with NCCL2 and RDMA
When doing distributed multi-GPU training, network bandwith often becomes the
bottle neck. We introduce a way to use NCCL2 to do such training job to
achieve best performace.
## Prepare Hardwares with RDMA and Multiple GPUs
I'm using two Linux servers each of them is installed with 8 GPUs and
one 100Gb RDMA card.
Base environment is:
* OS: CentOS 7.4
* RDMA device: "Mellanox Technologies MT27700 Family [ConnectX-4]"
* Kernel version: `4.4.88-1.el7.elrepo.x86_64`
* Docker version: `1.12.6`
* Docker storage driver: `overlay2`
* IP addresses: 192.168.16.30,192.168.16.34
In general, the steps including:
1. Install GPU drivers
1. Install RDMA drivers
1. Install "InfiniBand Support"
1. Use docker to run tests and make sure GPUs and RDMA can work inside
the container.
I'll ommit section "Install GPU drivers" because we can find it easily
somewhere else.
### Install RDMA drivers
For my case, I've got two machines with device
"Mellanox Technologies MT27700 Family [ConnectX-4]" installed. The OS was
"CentOS 7.4" and I updated the kernel to version 4.4 so that docker can
work with latest overlay2 filesystem.
***NOTE: before you start, make sure you have a way to get a console
of the server other than ssh because we may need to re-configure the
network device.***
1. Go to http://www.mellanox.com/page/products_dyn?product_family=26,
download `MLNX_OFED` software in the bottom of the page, and upload it
onto the server.
1. Run `./mlnxofedinstall --add-kernel-support` in the software package.
1. Run `/etc/init.d/openibd restart` to make everything work, note that
this operation may cause the network goes down if you are using this
RDMA device as default network device and use ssh to login the server.
1. Re-configure the network interface, for example:
`ifconfig eth2 192.168.16.30/20 up`, then add routes if needed:
`ip route add default via 192.168.16.1 dev eth2`.
1. Do the same thing on the other node.
1. Use `ping` to test if the two nodes have typical ICMP connection.
1. Use either `udaddy` or `ib_write_bw` to test the network connection is
ready and have the desired bandwith.
### Prepare Docker Image to Run RDMA Programs
1. Build a docker image using cuda base image like: `nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04` and install paddlepaddle whl
package in it.
1. Start a docker container and mount GPU driver libs into it (you can
skip this step if you are using nvidia-docker).
1. Mount RDMA dirvers and libs into the docker image (see below section),
also `udaddy` and `ib_write_bw` if needed.
1. Mount GPU devices and RDMA devices into the container using `--device`
or just use privileged mode `--privileged`.
1. Start the container using host network mode: `--net=host`
### RDMA Library Files Needed
Usually, `MLNX_OFED` install latest supported libs under
`/usr/lib64/mlnx_ofed/valgrind`. Other libs also needed to run RDMA programs
is listed below. These libs must be mounted into the docker container.
* Libs under `/usr/lib64/mlnx_ofed/valgrind`
* libibcm.so
* libibverbs.so
* libmlx4.so
* libmlx5.so
* libmlx5-rdmav2.so
* librdmacm.so
* Other libs:
* libnl-3.so.200
* libnl-route-3.so.200
* libnuma.so.1
## Start to Run the Training Job
Setting NCCL environment variables to turn NCCL switches on and off:
| Env Name | Description |
| --- | --- |
| NCCL_SOCKET_IFNAME | The RDMA device, e.g. eth2 |
| NCCL_P2P_DISABLE | Set to 1 to disable P2P transfer between GPUs |
| NCCL_IB_DISABLE | Set to 1 to disable using RDMA |
| NCCL_IB_CUDA_SUPPORT | Set to 1 to enable GPU Direct if supported |
| NCCL_DEBUG | Set debug level: VERSION, WARN, INFO |
My two servers are: `192.168.16.30,192.168.16.34`, On node 1, Run :
```bash
PADDLE_TRAINER_ID=0 PADDLE_PORT=48372 PADDLE_WORKERS=192.168.16.30,192.168.16.34 POD_IP=192.168.16.30 stdbuf -oL python vgg16.py
```
On node 2, Run:
```bash
PADDLE_TRAINER_ID=1 PADDLE_PORT=48372 PADDLE_WORKERS=192.168.16.30,192.168.16.34 POD_IP=192.168.16.34 stdbuf -oL python vgg16.py
```
......@@ -57,7 +57,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
......
......@@ -134,6 +134,11 @@ OpDesc *BlockDesc::PrependOp() {
return ops_.front().get();
}
void BlockDesc::PrependAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {
need_update_ = true;
ops_.emplace_front(std::move(op_desc));
}
OpDesc *BlockDesc::InsertOp(size_t index) {
need_update_ = true;
auto it = ops_.begin() + index;
......
......@@ -88,6 +88,8 @@ class BlockDesc {
OpDesc *PrependOp();
void PrependAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
OpDesc *InsertOp(size_t index);
/*
......
......@@ -32,8 +32,7 @@ struct AddFunctor {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
......
......@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated(*in_var_handle);
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
......@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
continue;
}
auto t_out_p = out_var_handle->place_;
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
if (platform::is_gpu_place(in_tensor.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
in_tensor.type());
}
InitOutputValue(*in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) {
......@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() {
}
}
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
void BroadcastOpHandle::InitOutputValue(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles) const {
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto *in_var =
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue;
}
auto t_out_p = out_var_handle->place_;
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
if (is_gpu_place(in_tensor.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
in_tensor.type());
}
}
......
......@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
private:
const std::vector<Scope *> &local_scopes_;
......@@ -65,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
#endif
void InitOutputValue(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles) const;
};
} // namespace details
} // namespace framework
......
......@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_(place) {}
void ComputationOpHandle::RunImpl() {
auto *cur_ctx = dev_ctxes_[place_];
for (auto *in : inputs_) {
bool need_wait = in->generated_op_ &&
in->generated_op_->DeviceContext(place_) != cur_ctx;
if (need_wait) {
in->generated_op_->Wait(cur_ctx);
}
}
WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details
} // namespace framework
......
......@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
virtual bool NeedWait(VarHandleBase *in_var);
private:
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
......
......@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
}
}
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
......@@ -45,14 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void FetchOpHandle::RunImpl() {
auto cpu_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
if (var->generated_op_) {
var->generated_op_->Wait(cpu_ctx);
}
}
WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_;
......@@ -79,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this->WaitAndMergeCPUTensors();
}
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details
......
......@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~FetchOpHandle();
void Wait(platform::DeviceContext *waited_dev) override;
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
void WaitAndMergeCPUTensors() const;
......@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
virtual void WaitInputVarGenerated(const platform::Place &place);
private:
FeedFetchList *data_;
size_t offset_;
......
......@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows.");
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles);
WaitInputVarGenerated();
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows;
......@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
});
}
void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details
} // namespace framework
......
......@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
private:
const std::vector<Scope *> &local_scopes_;
......
......@@ -34,12 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
WaitInputVarGenerated();
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
int dtype = -1;
......
......@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl();
}
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
......@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this;
}
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
}
}
}
}
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
......
......@@ -38,12 +38,24 @@ class OpHandleBase {
void Run(bool use_event);
virtual void Wait(platform::DeviceContext *waited_dev);
virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out);
// This method adds the wait events of all the input on all the device
// context.
// NODE: This Wait is asynchronous operation.
virtual void WaitInputVarGenerated();
// This method adds the wait events of all the input on the specified device
// context.
// NODE: This Wait is asynchronous operation.
virtual void WaitInputVarGenerated(const platform::Place &place);
virtual bool NeedWait(VarHandleBase *in_var);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }
......
......@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
......
......@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles);
WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx
......@@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() {
}
if (pre_in_var->IsType<framework::SelectedRows>()) {
std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
this->RunAndRecordEvent([&] {
std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
});
} else {
std::vector<const LoDTensor *> lod_tensors =
GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
this->RunAndRecordEvent([&] {
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
});
} else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA
auto pre_in = pre_in_var->Get<framework::LoDTensor>();
......@@ -157,17 +159,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return in_selected_rows;
}
void ReduceOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string ReduceOpHandle::Name() const { return "reduce"; }
} // namespace details
} // namespace framework
......
......@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
template <typename T>
std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
......
......@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -26,6 +26,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
place_(place) {}
void SendOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
......@@ -33,7 +34,7 @@ void SendOpHandle::RunImpl() {
continue;
}
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -14,8 +14,6 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
......@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
if (var.generated_op_ == nullptr) {
ready_vars.Push(&var);
}
};
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
pending_ops.insert({&op_instance, op_instance.Inputs().size()});
};
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(*version_pair);
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->dep_vars_) {
InsertPendingVar(*var);
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->ops_) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
InsertPendingOp(*op);
InsertPendingOp(&pending_ops, op.get());
}
}
// Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
FeedFetchList fetch_data(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
fetch_ops.emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
FeedFetchList fetch_data(fetch_tensors.size());
auto *fetch_dummy = new DummyVarHandle();
op->AddOutput(fetch_dummy);
fetch_dependencies.emplace(fetch_dummy);
InsertPendingVar(*fetch_dummy);
InsertPendingOp(*op);
}
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
......@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data;
}
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
this->InsertPendingOp(pending_ops, op);
}
}
void ThreadedSSAGraphExecutor::InsertPendingOp(
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const {
pending_ops->insert({op_instance, op_instance->Inputs().size()});
}
void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var);
if (var->generated_op_ == nullptr) {
ready_vars->Push(var);
}
}
void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
......
......@@ -23,6 +23,7 @@
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle {
......@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_;
bool allow_op_delay_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
void InsertPendingVar(std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const;
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data);
};
} // namespace details
......
......@@ -14,56 +14,57 @@ limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual void Make() = 0;
virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
CHECK(validated_) << "should call Validate after build";
}
void SetProto(proto::OpProto *proto) { proto_ = proto; }
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
proto::OpProto::Var *var_;
VariableBuilder& AsDuplicable() {
VariableBuilder &AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
VariableBuilder &AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& AsDispensable() {
VariableBuilder &AsDispensable() {
var_->set_dispensable(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddInput(const std::string &name, const std::string &comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
VariableBuilder AddOutput(const std::string &name,
const std::string &comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
TypedAttrChecker<T> &AddAttr(const std::string &name,
const std::string &comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
auto *attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
......@@ -71,21 +72,14 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
void AddComment(const std::string &comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
proto::OpProto *proto_;
OpAttrChecker *op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
......@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
......@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
......@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
......@@ -212,10 +210,7 @@ namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
void Make() { AddComment("NoGradOp, same input output. no Grad"); }
};
class OpWithKernelTest : public OperatorWithKernel {
......@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::DeviceContext;
namespace paddle {
namespace framework {
......
......@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op");
......@@ -98,8 +97,7 @@ namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("x", "input of test op");
AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker {
public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").AsDuplicable();
......
......@@ -24,8 +24,7 @@ namespace framework {
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
......
......@@ -98,7 +98,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
float x_v[2] = {1.0, 2.0};
engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
2 * sizeof(float));
2 * sizeof(float));
engine_->Execute(1);
LOG(INFO) << "to get output";
......
......@@ -166,6 +166,8 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......
......@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)");
AddInput("Indices", "The the network output of topk (indices)");
......
......@@ -19,19 +19,18 @@ limitations under the License. */
namespace paddle {
namespace operators {
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
......@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
......@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
......@@ -242,8 +239,7 @@ $$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
......@@ -265,8 +261,7 @@ $$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
......@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
......@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
......@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
......@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Pow operator");
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
......@@ -353,8 +344,7 @@ $out = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
......@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
......@@ -394,8 +383,7 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
......@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
......@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
......
......@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor.");
......
......@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
......
......@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignValueOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) Output tensor of assign_value operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
......
......@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]."
"Each row is sorted in descending order. This input should be the"
......
......@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("param", "(Tensor), The parameter to be accumulated.");
AddInput("in_sum_1",
"(Tensor), A tensor used to store the parameter "
......
......@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
......
......@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
......@@ -68,7 +67,11 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
Apply();
}
protected:
virtual void Apply() = 0;
};
} // namespace operators
......
......@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
......
......@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]");
......
......@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
......
......@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight",
......
......@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......
......@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
......
......@@ -21,8 +21,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
......
......@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kChannel,
"The Channel Variable that should be closed by"
" the ChannelClose Op.");
......
......@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.")
......
......@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase {
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.")
......
......@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase {
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.")
......
......@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). "
"Predictions from the network.");
......
......@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
......
......@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor)The input of clip op."
"The number of dimensions must be between [1, 9].");
......
......@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
......
......@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis",
......
......@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.")
......
......@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
library);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -200,8 +199,7 @@ $$
)DOC");
}
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......
......@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvOp : public framework::OperatorWithKernel {
......
......@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension.");
......
......@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_);
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DTransposeOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
......@@ -168,9 +166,7 @@ Example:
)DOC");
}
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DTransposeOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is "
......
......@@ -30,12 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvTransposeOp : public framework::OperatorWithKernel {
......
......@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op.");
......
......@@ -18,8 +18,7 @@ namespace paddle {
namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total "
......
......@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7).");
......
......@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
" where N is the batch size and D is the number of classes. "
......
......@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length.");
......
......@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator");
AddAttr<int>("axis",
......
......@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase {
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC(
Delete Operator.
......
......@@ -29,129 +29,127 @@ namespace paddle {
namespace operators {
namespace detail {
using VarMsg = sendrecv::VariableMessage;
void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim);
}
const framework::LoD lod = tensor.lod();
if (lod.size() > 0) {
request->set_lod_level(lod.size());
for (auto& each : lod) {
VarMsg::LodData* lod_inner = request->add_lod();
for (auto& d : each) {
lod_inner->add_lod_data(d);
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = tensor.data<void>();
}
*payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
}
void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0);
request->set_slr_height(slr->height());
for (auto& dim : framework::vectorize(slr->value().dims())) {
request->add_dims(dim);
}
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, *payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = slr->mutable_value()->data<void>();
}
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
// When using GPU, need to free the copied CPU buffer
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
VarMsg request;
void* payload = nullptr;
size_t payload_size;
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
request.set_profile(platform::IsProfileEnabled());
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
if (platform::is_gpu_place(ctx.GetPlace())) {
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
if (lod.size() > 0) {
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
for (auto& each : lod) {
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
2 + // tag + varintlength of submessage
1 + // kLodDataFieldNumber
each.size());
// auto copied from GPU
for (auto& d : each) {
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()),
copy_size, gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = tensor.data<void>();
}
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_SELECTED_ROWS: {
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(slr->value().type()));
for (auto& dim : framework::vectorize(slr->value().dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()),
copy_size, gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = slr->mutable_value()->data<void>();
}
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
default:
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
......@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (framework::ToVarType(var->Type()) ==
framework::proto::VarType_Type_SELECTED_ROWS) {
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
......@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size,
[](void* backing) {
// TODO(typhoonzero): add unref here, same as above.
},
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
......
......@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
int tensor_numel = 512 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place);
......@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[0], 512);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
......
......@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input->ReadVarintSizeAsInt(&length)) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
for (int i = 0; i < length; i++) {
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return false;
return tag;
}
lod->push_back(v);
}
......@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint64_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
......@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
......@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input.ReadVarintSizeAsInt(&length)) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
for (int i = 0; i < length; i++) {
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
......@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
......@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
......@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
"lod info should be got first!");
if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) {
if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag;
}
break;
}
if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag;
}
break;
......@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, length)) {
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
......
......@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
......
......@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel {
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
......
......@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings.");
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
REGISTER_OP_CPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = max(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = min(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
ops::ElementwiseMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,27 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X \\odot\\ Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
ops::ElementwiseMulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y");
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op.");
......@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
comment_ = R"DOC(
Limited Elementwise {name} Operator.
AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator.
The equation is:
$${equation}$$
$$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$.
......@@ -100,26 +99,13 @@ For example
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
)DOC";
AddComment(comment_);
)DOC",
GetName(), GetEquation()));
}
protected:
std::string comment_;
void Replace(std::string* src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(&comment_, "{name}", name);
Replace(&comment_, "{equation}", equation);
}
virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0;
};
class ElementwiseOpGrad : public framework::OperatorWithKernel {
......@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
};
} // namespace operators
} // namespace paddle
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
......@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker {
public:
ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Pow", "Out = X ^ Y");
AddComment(comment_);
}
protected:
std::string GetName() const override { return "Pow"; }
std::string GetEquation() const override { return "Out = X ^ Y"; }
};
} // namespace operators
} // namespace paddle
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_sub, ops::ElementwiseOp,
ops::ElementwiseSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
......
......@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
layout, library);
}
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
......
......@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel {
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FCOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
} // namespace operators
......
......@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed");
......
......@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op");
AddAttr<int>("col", "(int) The column of fetch");
......
......@@ -30,9 +30,8 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
};
class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddComment(R"DOC(Fill operator
Fill an tensor with `value` and `shape`. The type of the tensor is specify by
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册