提交 b61cf7ac 编写于 作者: L luotao1

Merge branch 'develop' into expand

......@@ -138,12 +138,6 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
if(WITH_MKL)
option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
if (MKL_SPLIT_GEMM)
add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
endif()
endif()
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND)
......
......@@ -54,7 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51"
GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
# Distributed Training with NCCL2
We design a pattern that can enable training with `ParallelExecutor` and
using [NCCL2](https://developer.nvidia.com/nccl) as it's collective
use [NCCL2](https://developer.nvidia.com/nccl) as it's collective
communication library.
In `ParallelExecutor` we can use `AllReduce` or `Reduce` and `Broadcast`
......@@ -9,14 +9,14 @@ to do multi GPU training. And if we initialize NCCL2 communicators as
ranks in a distributed environment, we can simply run the `ParallelExecutor`
as a distributed program! The only thing that may be different than in
the single node version is that we need to broadcast the NCCL unique ID
to all the nodes, and initialize communicators using that ID, so NCCL2
will know each other as ranks.
to all the nodes and initialize communicators using that ID, so NCCL2
can know each other as ranks.
To achieve this feature, we introduce a new operator: `gen_nccl_id` op,
so we are ***not*** "bind to" running NCCL2 with MPI, we can run it in
what ever platform you like.
whatever platform you like.
It have two running modes:
It has two running modes:
1. Generate and broadcast mode, which should be used on trainer 0;
1. Listen and fetch mode, which should be used on trainers other than 0.
......@@ -29,7 +29,7 @@ initialize NCCL communicator objects.
<img src="src/ncc2_design.png">
The above figure indicates the general process when training with NCCL2
distributed. Each trainer have the number of communicators equal to the
distributed. Each trainer has the number of communicators equal to the
number of GPUs, but the ranks should match the global ranks number: here
we have total 8 GPUs, so `nranks==8`, for each trainer, the ranks should
be from 0 ~ 3 on trainer 0 and 4 ~ 7 on trainer 1.
......@@ -36,19 +36,19 @@
<tbody>
<tr>
<td>OpProtoMake定义 </td>
<td>`.cc`文件,Backward Op不需要定义OpProtoMake </td>
<td>.cc 文件,Backward Op不需要定义OpProtoMake </td>
</tr>
<tr>
<td>Op定义 </td>
<td> `.cc`文件</td>
<td> .cc 文件</td>
</tr>
<tr>
<td>Kernel实现 </td>
<td> CPU、CUDA共享Kernel实现在`.h`文件中,否则,CPU 实现在`.cc`文件中,CUDA 实现在`.cu`文件中。</td>
<td> CPU、CUDA共享Kernel实现在.h 文件中,否则,CPU 实现在.cc 文件中,CUDA 实现在.cu 文件中。</td>
</tr>
<tr>
<td>注册Op </td>
<td> Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,CUDA实现在`.cu`文件中</td>
<td> Op注册实现在.cc 文件;Kernel注册CPU实现在.cc 文件中,CUDA实现在.cu 文件中</td>
</tr>
</tbody>
</table>
......@@ -391,7 +391,7 @@ PADDLE_ENFORCE(ctx->HasInput("X"), "");
```
问题示例2 :提示信息过于简单
```
PADDLE_ENFORCE(i != nullptr, "I must be set"); // I是什么?
PADDLE_ENFORCE(i != nullptr, "i must be set"); // i是什么?
```
2. 在报错信息中使用开发人员定义的变量缩写,不易理解!
......
# Distributed Training with NCCL2 and RDMA
When doing distributed multi-GPU training, network bandwith often becomes the
bottle neck. We introduce a way to use NCCL2 to do such training job to
achieve best performace.
When doing distributed multi-GPU training, network bandwidth often becomes the
bottleneck. We introduce a way to use NCCL2 to do such training job to
achieve best performance.
## Prepare Hardwares with RDMA and Multiple GPUs
## Prepare Hardware with RDMA and Multiple GPUs
I'm using two Linux servers each of them is installed with 8 GPUs and
I'm using two Linux servers each of them installed with 8 GPUs and
one 100Gb RDMA card.
Base environment is:
......@@ -25,7 +25,7 @@ In general, the steps including:
1. Use docker to run tests and make sure GPUs and RDMA can work inside
the container.
I'll ommit section "Install GPU drivers" because we can find it easily
I'll omit the section "Install GPU drivers" because we can find it easily
somewhere else.
### Install RDMA drivers
......@@ -33,7 +33,7 @@ somewhere else.
For my case, I've got two machines with device
"Mellanox Technologies MT27700 Family [ConnectX-4]" installed. The OS was
"CentOS 7.4" and I updated the kernel to version 4.4 so that docker can
work with latest overlay2 filesystem.
work with the latest overlay2 filesystem.
***NOTE: before you start, make sure you have a way to get a console
of the server other than ssh because we may need to re-configure the
......@@ -45,14 +45,14 @@ network device.***
1. Run `./mlnxofedinstall --add-kernel-support` in the software package.
1. Run `/etc/init.d/openibd restart` to make everything work, note that
this operation may cause the network goes down if you are using this
RDMA device as default network device and use ssh to login the server.
RDMA device as default network device and use ssh to log in the server.
1. Re-configure the network interface, for example:
`ifconfig eth2 192.168.16.30/20 up`, then add routes if needed:
`ip route add default via 192.168.16.1 dev eth2`.
1. Do the same thing on the other node.
1. Use `ping` to test if the two nodes have typical ICMP connection.
1. Use either `udaddy` or `ib_write_bw` to test the network connection is
ready and have the desired bandwith.
ready and have the desired bandwidth.
### Prepare Docker Image to Run RDMA Programs
......@@ -60,7 +60,7 @@ network device.***
package in it.
1. Start a docker container and mount GPU driver libs into it (you can
skip this step if you are using nvidia-docker).
1. Mount RDMA dirvers and libs into the docker image (see below section),
1. Mount RDMA drivers and libs into the docker image (see below section),
also `udaddy` and `ib_write_bw` if needed.
1. Mount GPU devices and RDMA devices into the container using `--device`
or just use privileged mode `--privileged`.
......
......@@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace framework {
template <typename T, size_t N>
class Array {
static_assert(N > 0, "The size of array must be larger than 0");
public:
HOSTDEVICE Array() {}
HOSTDEVICE explicit Array(const T &val) {
for (size_t i = 0; i < N; ++i) data_[i] = val;
}
HOSTDEVICE const T *Get() const { return data_; }
HOSTDEVICE T *GetMutable() { return data_; }
HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
HOSTDEVICE constexpr size_t size() const { return N; }
private:
T data_[N];
};
} // namespace framework
} // namespace paddle
......@@ -129,10 +129,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
"Optimized for variable")
.SetDefault({});
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate();
}
......
......@@ -39,7 +39,6 @@ class OpProtoAndCheckerMaker {
public:
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
......@@ -11,17 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -129,48 +127,19 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
try {
if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
VLOG(3) << place << " " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
......@@ -198,7 +167,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
}
bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.end() != outputs_.find(name)) {
if (outputs_.find(name) != outputs_.end()) {
return true;
} else {
return false;
......
......@@ -31,7 +31,8 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
void* Tensor::mutable_data(platform::Place place, std::type_index type) {
void* Tensor::mutable_data(platform::Place place, std::type_index type,
size_t requested_size) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
......@@ -39,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) {
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first.");
int64_t size = numel() * SizeOfType(type);
size_t size = requested_size ? requested_size : numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
......@@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) {
offset_);
}
void* Tensor::mutable_data(platform::Place place) {
void* Tensor::mutable_data(platform::Place place, size_t requested_size) {
PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type());
return mutable_data(place, holder_->type(), requested_size);
}
Tensor& Tensor::ShareDataWith(const Tensor& src) {
......
......@@ -89,22 +89,24 @@ class Tensor {
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(platform::Place place);
T* mutable_data(platform::Place place, size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type);
void* mutable_data(platform::Place place, std::type_index type,
size_t requested_size = 0);
void* mutable_data(platform::Place place);
void* mutable_data(platform::Place place, size_t requested_size = 0);
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
* @param[in] requested_size The size of the block in bytes.
*
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(DDim dims, platform::Place place);
T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0);
/*! Return the dimensions of the memory block. */
const DDim& dims() const;
......
......@@ -46,16 +46,17 @@ inline T* Tensor::data() {
}
template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
inline T* Tensor::mutable_data(DDim dims, platform::Place place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place);
return mutable_data<T>(place, requested_size);
}
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T)));
return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size));
}
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h"
#include <sys/time.h>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
}
framework::DDim out_dims({x_dims[0], D});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
ctx->SetOutputDim("LSTMX", {1, M});
ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
// AttentionFCOut should be reshape as (maxseqlen,1) in runtime
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
}
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void AttentionLSTMOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("C0",
"(Tensor) LSTM C0"
"This is a tensor with shape (N x D), where N is the batch size, D "
"is the gate size."
"C0 is necessary because of attention.");
AddInput("H0",
"(Tensor, optional) LSTM H0"
"This is a tensor with shape (N x D), where N is the "
"batch size and D is the gate size.")
.AsDispensable();
AddInput("AttentionWeight",
"(Tensor) the weights of attention fc. Always relu the fc result."
"The shape is ((M+D) x 1), where M is the dim size of x, D is the "
"gate size of LSTM.");
AddInput("AttentionBias",
"(Tensor, optional) the bias of attention fc."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("AttentionScalar",
"(Tensor, optional) the scalar on the result of attentioned fc. "
"Always relu the Scalar."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("AttentionScalarBias",
"(Tensor, optional) the scalar bias of attention fc."
"The shape is (1 x 1)")
.AsDispensable();
AddInput("LSTMWeight",
"(Tensor) the combined weight of LSTM"
" - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
"is the dim size of x"
" - Weight = {W_forget, W_input, W_output, W_cell}");
AddInput("LSTMBias",
"(Tensor) the combined bias of LSTM, shape (1x4D)."
"Note: we should add the bias of hidden and context accorindg to "
"the same gate: "
"{B_forget, B_input, B_output, B_cell}");
AddOutput("Hidden",
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("AttentionedX",
"(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size.")
.AsIntermediate();
AddOutput("AttentionFCOut",
"(Tensor) (max_seq_len, 1), compute at each step.")
.AsIntermediate();
AddOutput("LSTMX",
"(Tensor) the input X of LSTM for each step."
"Shape is (1 x M), where M is the x frame size")
.AsIntermediate();
AddOutput(
"LSTMOUT",
"(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
"Shape is (1 x 4D), where M is the x frame size")
.AsIntermediate();
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Attention Long-Short Term Memory (LSTM) Operator.
Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))
tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu
fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu
dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M)
LSTM part:
use lstm_x_t as input and compute as standard LSTM.
)DOC");
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
if (bias) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] + bias[0];
}
math::vec_relu<T>(n, y, y);
} else {
math::vec_relu<T>(n, x, y);
}
}
template <typename DeviceContext, typename T>
inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
const T* x, T* y) {
T scalar = x[0];
// max
for (int i = 1; i < n; ++i) {
scalar = scalar < x[i] ? x[i] : scalar;
}
// sub
for (int i = 0; i < n; ++i) {
y[i] = x[i] - scalar;
}
// exp
blas.VEXP(n, y, y);
// sum
scalar = T(0);
for (int i = 0; i < n; ++i) {
scalar += y[i];
}
// scale
blas.SCAL(n, static_cast<T>(1) / scalar, y);
}
template <typename T>
class AttentionLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
auto* atten_b = ctx.Input<Tensor>("AttentionBias");
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* atted_x = ctx.Output<Tensor>("AttentionedX");
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
auto* lstm_x = ctx.Output<Tensor>("LSTMX");
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
// some shape should be reshape here since infershape can not get lod info
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; // batch size
auto x_dims = x->dims(); // T x M
auto w_dims = lstm_w->dims(); // (D+M) x 4D
const int total_T = x_dims[0];
const int M = x_dims[1]; // x frame size
const int D = w_dims[1] / 4; // gate frame size
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = w_dims[1];
int max_seq_len = x_lod[0][1];
for (int i = 1; i < N; ++i) {
int len = x_lod[0][i + 1] - x_lod[0][i];
max_seq_len = max_seq_len < len ? len : max_seq_len;
}
PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
math::VecActivations<T> act_functor;
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0->data<T>();
const T* lstm_w_data = lstm_w->data<T>();
const T* lstm_b_data = lstm_b->data<T>();
const T* atten_w_data = atten_w->data<T>();
const T* atten_b_data = atten_b ? atten_b->data<T>() : NULL;
const T* atten_scalar_data = atten_scalar ? atten_scalar->data<T>() : NULL;
const T* atten_scalar_bias_data =
atten_scalar_bias ? atten_scalar_bias->data<T>() : NULL;
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
T* atted_x_data = atted_x->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
atted_x_data, atten_b_data);
const T* cur_atten_x_data = atted_x_data;
const T* cur_x_data = x_data;
const T* prev_cell_data = NULL;
const T* prev_hidden_data = NULL;
T* cur_cell_out_data = cell_out_data;
T* cur_hidden_out_data = hidden_out_data;
for (int i = 0; i < N; ++i) {
int seq_len = x_lod[0][i + 1] - x_lod[0][i];
prev_cell_data = c0_data + i * D;
prev_hidden_data = h0_data ? h0_data + i * D : NULL;
for (int step = 0; step < seq_len; ++step) {
/// 1. compute attention vector
// 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
// 1b. add cell bias and relu
bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
// 1c. fc scalar
if (atten_scalar_data) {
blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
fc_out_data);
}
// 1d. softmax
vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
// mul x(seq_len*M) and sum pool
math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
cur_x_data, lstm_x_data);
/// 2. compute LSTM step
// lstm weight : concat[forget , input , output , tilde]
// shape : (D + M) x (4 * D)
// fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
blas.MatMul(1, D4, M, lstm_x_data, lstm_w_data + D * D4, lstm_out_data);
if (prev_hidden_data) {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_hidden_data, D, lstm_w_data, D4, static_cast<T>(1),
lstm_out_data, D4);
}
// since input is 1xM, so can use add bias
blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
// gate act: sigmoid
act_gate(D3, lstm_out_data, lstm_out_data);
// candicate act: tanh
act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
// a = forget * prev_cell
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
// b = input * tilde
blas.VMUL(D, lstm_out_data + D, lstm_out_data + D3, lstm_out_data + D);
// cell_out = a + b
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
// state act tanh(cell_out) * output_gate
act_cell(D, cur_cell_out_data, lstm_out_data);
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
prev_hidden_data = cur_hidden_out_data;
prev_cell_data = cur_cell_out_data;
cur_cell_out_data = cur_cell_out_data + D;
cur_hidden_out_data = cur_hidden_out_data + D;
}
cur_x_data = cur_x_data + seq_len * M;
cur_atten_x_data = cur_atten_x_data + seq_len;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
ops::AttentionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
ops::AttentionLSTMKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class AttentionLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance",
"The global variance (for training) "
"or estimated Variance (for testing)");
AddOutput("Y", "result after normalization");
AddOutput("Y", "result after normalization").Reuse("X");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training")
......
......@@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
key_ += "-BWD";
}
size_t GetDstMemorySize() const {
return conv_pd_->dst_primitive_desc().get_size();
}
size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
......@@ -294,7 +306,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
......@@ -354,6 +365,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
......@@ -476,13 +489,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
T* input_grad_data = nullptr;
T* filter_grad_data = nullptr;
if (input_grad) {
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (filter_grad) {
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
......@@ -568,6 +574,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline);
const size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_weights_memory_p =
handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
reinterpret_cast<void*>(filter_grad_data));
......@@ -590,6 +599,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline);
const size_t size = handler.GetDiffSourceMemorySize();
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
reinterpret_cast<void*>(input_grad_data));
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......
......@@ -90,6 +90,11 @@ class Blas {
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
......@@ -109,6 +114,10 @@ class Blas {
void GEMM_FREE(T* data) const;
#endif
template <typename T>
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
T* C) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a, bool trans_a,
const framework::Tensor& mat_b, bool trans_b, T alpha,
......@@ -140,10 +149,19 @@ class Blas {
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
template <typename T>
void VEXP(int n, const T* x, T* y) const;
template <typename T>
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
T* C) const;
template <typename T>
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, T* x) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
int K, T alpha, const T* A, const T* B, T beta, T* C,
......@@ -215,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VCOPY<T>(args...);
}
template <typename... ARGS>
void VEXP(ARGS... args) const {
Base()->template VEXP<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
T DOT(ARGS... args) const {
return Base()->template DOT<T>(args...);
}
template <typename... ARGS>
void SCAL(ARGS... args) const {
Base()->template SCAL<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);
......
......@@ -73,6 +73,16 @@ struct CBlas<float> {
platform::dynload::cblas_sgemv(args...);
}
template <typename... ARGS>
static float DOT(ARGS... args) {
return platform::dynload::cblas_sdot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_sscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
......@@ -87,6 +97,11 @@ struct CBlas<float> {
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vsExp(args...);
}
};
template <>
......@@ -138,6 +153,16 @@ struct CBlas<double> {
platform::dynload::cblas_dgemv(args...);
}
template <typename... ARGS>
static double DOT(ARGS... args) {
return platform::dynload::cblas_ddot(args...);
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
platform::dynload::cblas_dscal(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
......@@ -152,6 +177,11 @@ struct CBlas<double> {
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
template <typename... ARGS>
static void VEXP(ARGS... args) {
platform::dynload::vdExp(args...);
}
};
#else
......@@ -210,6 +240,9 @@ struct CBlas<platform::float16> {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
......@@ -217,64 +250,6 @@ struct CBlas<platform::float16> {
#endif
};
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <>
inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
bool transa, bool transb,
const platform::float16 &alpha,
const platform::float16 &beta) {
return false;
}
template <typename T>
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
const T *A, int lda, const T *B, int ldb, T beta, T *C,
int ldc) {
#ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) {
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc);
return;
}
#endif
#ifdef PADDLE_MKL_SPLIT_GEMM
constexpr int bs = 2;
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
for (int off = 0; off < M; off += bs) {
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
}
return;
}
#endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
#ifdef PADDLE_WITH_MKLML
template <>
template <typename T>
......@@ -319,8 +294,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
......@@ -329,9 +304,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <typename DeviceContext>
......@@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VEXP(n, x, y);
#else
// try to find if openblas support vexp
for (int i = 0; i < n; ++i) {
y[i] = std::exp(x[i]);
}
#endif
}
template <>
template <typename T>
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
#ifdef PADDLE_WITH_MKLML
return CBlas<T>::DOT(n, x, 1, y, 1);
#else
// try to find if openblas support cblas_dot
T sum = 0;
for (int i = 0; i < n; ++i) {
sum += x[i] * y[i];
}
return sum;
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::SCAL(n, a, x, 1);
#else
// try to find if openblas support cblas_scal
for (int i = 0; i < n; ++i) {
x[i] = a * x[i];
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......@@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
const T *A, const T *B, T *C) const {
this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
N);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
const int K, const T *A,
const T *B, T *C) const {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
// Since the matrix is very small,
// so the unit of calculation is already very fast,
// and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
// use xsmm directly.
// Note: SMM use ColMajor
const char transa = 'N';
const char transb = 'N';
const T alpha = static_cast<T>(1);
const T beta = static_cast<T>(0);
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
C, &N);
return;
#endif
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_identity(const int n, const T* x, T* y) {
// do nothing
return;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 1.0 / (1.0 + std::exp(-tmp));
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_tanh(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = tanh<T>(x[i]);
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
float* y) {
// TODO(TJ): complete me
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
class VecActivations {
public:
std::function<void(const int, const T*, T*)> operator()(
const std::string& type) {
if (type == "sigmoid") {
return vec_sigmoid<T, isa>;
} else if (type == "relu") {
return vec_relu<T, isa>;
} else if (type == "tanh") {
return vec_tanh<T, isa>;
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
PADDLE_THROW("Not support type %s.", type);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -25,17 +25,25 @@ namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
static_cast<T>(0), Y);
if (B) {
const T* B = NULL, bool relu = false) {
blas.MatMul(M, N, K, X, W, Y);
if (B == NULL) {
return;
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
if (!relu) {
return;
}
// TODO(TJ): fuse relu
LOG(FATAL) << "Not implemented!";
}
} // namespace math
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops::StackGradOpDescMaker);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
ops::StackKernel<plat::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(stack_grad,
ops::StackGradKernel<plat::CPUDeviceContext, float>,
ops::StackGradKernel<plat::CPUDeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
ops::StackKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(stack_grad,
ops::StackGradKernel<plat::CUDADeviceContext, float>,
ops::StackGradKernel<plat::CUDADeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#ifdef __NVCC__
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#endif
namespace paddle {
namespace operators {
class StackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
"Number of Inputs(X) must be larger than 0");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
auto input_dims = ctx->GetInputsDim("X");
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
"Dims of all Inputs(X) must be the same");
}
// Only lod of X[0] would be shared with Y
ctx->ShareLoD("X", /*->*/ "Y");
int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
PADDLE_ENFORCE(
axis >= -(rank + 1) && axis < rank + 1,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
if (axis < 0) axis += (rank + 1);
auto vec = framework::vectorize2int(input_dims[0]);
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim("Y", framework::make_ddim(vec));
}
};
class StackOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of stack op.").AsDuplicable();
AddOutput("Y", "The output of stack op.");
AddAttr<int>("axis",
"The axis along which all of the Inputs(X) should be stacked.")
.SetDefault(0);
AddComment(R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
}
};
template <typename VecXType, typename T>
struct StackFunctor {
HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
: x_(x), y_(y), n_(n), post_(post) {}
HOSTDEVICE void operator()(int idx) {
int i = idx / (n_ * post_);
int which_x = idx / post_ - i * n_;
int x_index = i * post_ + idx % post_;
y_[idx] = x_[which_x][x_index];
}
private:
VecXType x_;
T *y_;
int n_;
int post_;
};
template <typename VecDxType, typename T>
struct StackGradFunctor {
HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
: dx_(dx), dy_(dy), n_(n), post_(post) {}
HOSTDEVICE void operator()(int idx) {
int i = idx / (n_ * post_);
int which_x = idx / post_ - i * n_;
int x_index = i * post_ + idx % post_;
dx_[which_x][x_index] = dy_[idx];
}
private:
VecDxType dx_;
const T *dy_;
int n_;
int post_;
};
template <typename DeviceContext, typename VecXType, typename T>
static inline void StackFunctorForRange(const DeviceContext &ctx,
const VecXType &x, T *y, int total_num,
int n, int post) {
platform::ForRange<DeviceContext> for_range(ctx, total_num);
for_range(StackFunctor<VecXType, T>(x, y, n, post));
}
template <typename DeviceContext, typename VecDxType, typename T>
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
const VecDxType &dx, const T *dy,
int total_num, int n, int post) {
platform::ForRange<DeviceContext> for_range(ctx, total_num);
for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
}
template <typename DeviceContext, typename T>
class StackKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = static_cast<int>(x.size());
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
std::vector<const T *> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
int pre = 1, post = 1;
auto &dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
constexpr auto kMaxThreshold = 16;
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
n > kMaxThreshold) {
#ifdef __NVCC__
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors on GPU may be slow.";
thrust::device_vector<const T *> device_x_vec(x_datas);
auto x_data_arr = device_x_vec.data().get();
#else
auto x_data_arr = x_datas.data();
#endif
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
#ifdef __NVCC__
// Wait() must be called because device_x_vec may be destructed before
// kernel ends
dev_ctx.Wait();
#endif
}
#ifdef __NVCC__
else { // NOLINT
framework::Array<const T *, kMaxThreshold> x_data_arr;
for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
}
#endif
}
};
class StackOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@Grad) must exist.");
int axis = ctx->Attrs().Get<int>("axis");
auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = dy_dim.size();
PADDLE_ENFORCE(axis >= -rank && axis < rank,
"Attr(axis) must be inside [-rank, rank), where rank = %d",
rank);
if (axis < 0) axis += rank;
PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
static_cast<size_t>(dy_dim[axis]),
"Number of Outputs(X@Grad) is wrong");
auto vec = framework::vectorize2int(dy_dim);
vec.erase(vec.begin() + axis);
ctx->SetOutputsDim(
framework::GradVarName("X"),
std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
}
};
class StackGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("stack_grad");
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
template <typename DeviceContext, typename T>
class StackGradKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis += dy->dims().size();
int n = dy->dims()[axis];
std::vector<T *> dx_datas(n); // NOLINT
for (int i = 0; i < n; i++) {
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
}
auto dy_data = dy->data<T>();
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
int total_num = dy->numel();
int post = total_num / (n * pre);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
constexpr auto kMaxThreshold = 16;
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
n > kMaxThreshold) {
#ifdef __NVCC__
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors on GPU may be slow.";
thrust::device_vector<T *> device_dx_vec(dx_datas);
auto dx_data_arr = device_dx_vec.data().get();
#else
auto dx_data_arr = dx_datas.data();
#endif
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
post);
#ifdef __NVCC__
// Wait() must be called because device_dx_vec may be destructed before
// kernel ends
dev_ctx.Wait();
#endif
}
#ifdef __NVCC__
else { // NOLINT
framework::Array<T *, kMaxThreshold> dx_data_arr;
for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i];
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
post);
}
#endif
}
};
} // namespace operators
} // namespace paddle
......@@ -30,8 +30,6 @@ class TopkOp : public framework::OperatorWithKernel {
"Output(Indices) of TopkOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
"Rank of TopK op's input must be 2.");
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
......
......@@ -103,15 +103,16 @@ size_t CUDAPinnedMaxChunkSize() {
return CUDAPinnedMaxAllocSize() / 256;
}
#ifdef PADDLE_WITH_XBYAK
namespace jit {
#ifdef PADDLE_WITH_XBYAK
static Xbyak::util::Cpu cpu;
bool MayIUse(const cpu_isa_t cpu_isa) {
using namespace Xbyak::util; // NOLINT
switch (cpu_isa) {
case sse42:
return cpu.has(Cpu::tSSE42);
case avx:
return cpu.has(Cpu::tAVX);
case avx2:
return cpu.has(Cpu::tAVX2);
case avx512_common:
......@@ -134,8 +135,16 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
}
return false;
}
#else
bool MayIUse(const cpu_isa_t cpu_isa) {
if (cpu_isa == isa_any) {
return true;
} else {
return false;
}
}
#endif
} // namespace jit
#endif
} // namespace platform
} // namespace paddle
......@@ -37,12 +37,11 @@ size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize();
#ifdef PADDLE_WITH_XBYAK
namespace jit {
typedef enum {
isa_any,
sse42,
avx,
avx2,
avx512_common,
avx512_core,
......@@ -55,7 +54,6 @@ typedef enum {
inline bool MayIUse(const cpu_isa_t cpu_isa);
} // namespace jit
#endif
} // namespace platform
} // namespace paddle
......@@ -66,10 +66,16 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(cblas_sdot); \
__macro(cblas_ddot); \
__macro(cblas_sscal); \
__macro(cblas_dscal); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
__macro(vdMul); \
__macro(vsExp); \
__macro(vdExp); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
......@@ -43,9 +43,6 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def(
"kOpRoleVarAttrName",
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
}
} // namespace pybind
......
......@@ -19,6 +19,7 @@ import hashlib
import os
import errno
import shutil
import six
import sys
import importlib
import paddle.dataset
......@@ -94,6 +95,8 @@ def download(url, module_name, md5sum, save_name=None):
dl = 0
total_length = int(total_length)
for data in r.iter_content(chunk_size=4096):
if six.PY2:
data = six.b(data)
dl += len(data)
f.write(data)
done = int(50 * dl / total_length)
......
......@@ -35,6 +35,7 @@ import itertools
import functools
from .common import download
import tarfile
import six
import scipy.io as scio
from paddle.dataset.image import *
from paddle.reader import *
......@@ -45,10 +46,10 @@ from six.moves import cPickle as pickle
from six.moves import zip
__all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118'
DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
......@@ -120,7 +121,10 @@ def reader_creator(data_file,
file = file.strip()
batch = None
with open(file, 'rb') as f:
batch = pickle.load(f)
if six.PY2:
batch = pickle.load(f)
else:
batch = pickle.load(f, encoding='bytes')
data = batch['data']
labels = batch['label']
for sample, label in zip(data, batch['label']):
......
......@@ -18,7 +18,6 @@ import collections
import contextlib
import re
import six
import traceback
import numpy as np
......@@ -506,10 +505,6 @@ class Operator(object):
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
del op_attrs[role_var_name]
callstack_var_name = op_maker.kOpCreationCallstackAttrName()
op_attrs[callstack_var_name] = list(
reversed(traceback.format_stack()))[1:]
if len(self.desc.type()) != 0:
return
if type is None:
......
......@@ -27,7 +27,6 @@ from . import utils
import random
from .. import unique_name
from functools import reduce
import warnings
__all__ = [
'fc',
......@@ -104,6 +103,7 @@ __all__ = [
'rank_loss',
'prelu',
'flatten',
'stack',
]
......@@ -2047,7 +2047,7 @@ def batch_norm(input,
param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
data_layout(string, default NCHW): NCHW|NHWC
in_place(bool, Default False): This argument is deprecated since 0.15.0.
in_place(bool, Default False): Make the input and output of batch norm reuse memory.
use_mkldnn(bool, Default false): ${use_mkldnn_comment}
name(string, Default None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -2069,10 +2069,6 @@ def batch_norm(input,
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
if in_place:
raise warnings.warn("The argument in_place is deprecated since 0.15.0, "
"please do not set it True.")
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
......@@ -2122,7 +2118,7 @@ def batch_norm(input,
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
batch_norm_out = helper.create_tmp_variable(dtype)
batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",
......@@ -5522,3 +5518,17 @@ def flatten(x, axis=1, name=None):
outputs={'Out': out},
attrs={"axis": axis})
return out
def stack(x, axis=0):
helper = LayerHelper('stack', **locals())
axis = 0 if axis is None else axis
if not isinstance(x, list) and not isinstance(x, tuple):
x = [x]
out = helper.create_tmp_variable(x[0].dtype)
helper.append_op(
type='stack', inputs={'X': x}, outputs={'Y': out},
attrs={'axis': axis})
return out
......@@ -229,7 +229,7 @@ def img_conv_group(input,
use_mkldnn=use_mkldnn)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act)
tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True)
drop_rate = conv_batchnorm_drop_rate[i]
if abs(drop_rate) > 1e-5:
tmp = layers.dropout(x=tmp, dropout_prob=drop_rate)
......
......@@ -256,10 +256,7 @@ def main(net_type, use_cuda, is_local=True):
save_dirname = "image_classification_" + net_type + ".inference.model"
train(net_type, use_cuda, save_dirname, is_local)
# There is bug in fluid.InferenceTranspiler for VGG.
if net_type == "resnet":
infer(use_cuda, save_dirname)
infer(use_cuda, save_dirname)
class TestImageClassification(unittest.TestCase):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_fusion_lstm_op import fc, ACTIVATION
from test_softmax_op import stable_softmax
def attention_lstm(
x, # T x M
lod, # 1 x N
h0, # N x D
c0, # N x D
fcws, # (M+D) x 1, 1x1
fcbs, # 1 x 1, 1x1
w, # (M+D) x 4D
b, # 1 x 4D
act_gate,
act_cell,
act_cand):
T = sum(lod[0])
N = len(lod[0])
M = x.shape[1]
D = b.shape[1] / 4
assert T == x.shape[0]
assert len(fcws) == len(fcbs)
hidden = []
cell = []
start_offset = 0
for bid in range(N):
seq_len = lod[0][bid]
xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
M)
prev_cell = np.copy(c0[bid]).reshape([1, D])
prev_hidden = np.copy(h0[bid]).reshape([1, D])
for step in range(seq_len):
expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
tmp = np.concatenate((xi, expanded_cell), axis=1)
assert tmp.shape[0] == seq_len
assert tmp.shape[1] == M + D
for fcid in range(len(fcbs)):
tmp = fc(tmp, fcws[fcid], fcbs[fcid])
tmp = ACTIVATION['relu'](tmp)
tmp = np.reshape(tmp, (1, seq_len))
tmp = stable_softmax(tmp).reshape(seq_len, 1)
lstmx = xi * tmp # seq * M
lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
g_f = act_gate(g_f).reshape([1, D])
g_i = act_gate(g_i).reshape([1, D])
g_o = act_gate(g_o).reshape([1, D])
cand = act_cand(cand).reshape([1, D])
cell_t = (prev_cell * g_f) + (g_i * cand)
hidden_t = g_o * act_cell(cell_t)
hidden.append(hidden_t.flatten())
cell.append(cell_t.flatten())
prev_cell = cell_t.reshape([1, D])
prev_hidden = hidden_t.reshape([1, D])
start_offset += seq_len
hidden = np.array(hidden).astype('float32').reshape([T, D])
cell = np.array(cell).astype('float32').reshape([T, D])
return hidden, cell
class TestAttentionLSTMOp(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'attention_lstm'
self.lod = [[3]]
self.M = 30
self.D = 15
self.has_initial_hidden = True
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float32')
c0 = np.random.normal(size=(bs, self.D)).astype('float32')
if self.has_initial_hidden:
h0 = np.random.normal(size=(bs, self.D)).astype('float32')
else:
h0 = np.zeros((bs, self.D)).astype('float32')
fcw1 = np.random.normal(size=(self.M + self.D, 1)).astype('float32')
fcb1 = np.random.normal(size=(1, 1)).astype('float32')
fcw2 = np.random.normal(size=(1, 1)).astype('float32')
fcb2 = np.random.normal(size=(1, 1)).astype('float32')
# lstm weight and bias
w = np.random.normal(size=(self.M + self.D,
self.D * 4)).astype('float32')
b = np.random.normal(size=(1, self.D * 4)).astype('float32')
h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2],
w, b, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cell],
ACTIVATION[self.act_cand])
self.inputs = {
'X': (x, self.lod),
'C0': c0,
'AttentionWeight': fcw1,
'AttentionBias': fcb1,
'AttentionScalar': fcw2,
'AttentionScalarBias': fcb2,
'LSTMWeight': w,
'LSTMBias': b
}
if self.has_initial_hidden:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
}
self.attrs = {
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output()
class TestAttentionOpNonInit(TestAttentionLSTMOp):
def set_conf(self):
self.has_initial_hidden = False
class TestAttentionOpAct(TestAttentionLSTMOp):
def set_conf(self):
self.M = 3
self.D = 2
self.act_gate = 'relu'
self.act_cell = 'tanh'
self.act_cand = 'sigmoid'
class TestAttentionOpMD1(TestAttentionLSTMOp):
def set_conf(self):
self.M = 36
self.D = 8
class TestAttentionOpMD2(TestAttentionLSTMOp):
def set_conf(self):
self.M = 8
self.D = 8
class TestAttentionOpMD3(TestAttentionLSTMOp):
def set_conf(self):
self.M = 15
self.D = 30
class TestAttentionOpBS1(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[5]]
self.M = 16
self.D = 32
class TestAttentionOpBS2(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 6]]
class TestAttentionOpBS5(TestAttentionLSTMOp):
def set_conf(self):
self.lod = [[3, 2, 4, 7, 5]]
if __name__ == '__main__':
unittest.main()
......@@ -67,10 +67,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(
set(mul_op.attr_names),
set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_callstack"
]))
set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import unittest
from multiprocessing import Process
import signal
import numpy
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers.io import ListenAndServ
from paddle.fluid.layers.io import Recv
from paddle.fluid.layers.io import Send
from paddle.fluid.transpiler.details import program_to_code
class TestProgram2Code(unittest.TestCase):
def test_print(self):
place = fluid.CPUPlace()
self.init_serv(place)
self.init_client(place, 9123)
def init_serv(self, place):
main = fluid.Program()
with fluid.program_guard(main):
serv = ListenAndServ("127.0.0.1:0", ["X"], optimizer_mode=False)
with serv.do():
out_var = main.global_block().create_var(
name="scale_0.tmp_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
layers.scale(x=x, scale=10.0, out=out_var)
program_to_code(main)
def init_client(self, place, port):
main = fluid.Program()
with fluid.program_guard(main):
x = layers.data(
shape=[32, 32],
dtype='float32',
name='X',
append_batch_size=False)
fluid.initializer.Constant(value=2.3)(x, main.global_block())
get_var = main.global_block().create_var(
name="scale_0.tmp_0", # server side var
dtype="float32",
persistable=False,
shape=[32, 32])
fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
Send("127.0.0.1:%d" % port, [x])
o = Recv("127.0.0.1:%d" % port, [get_var])
program_to_code(main)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import numpy as np
import unittest
class TestStackOpBase(OpTest):
def initDefaultParameters(self):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'float32'
def initParameters(self):
pass
def get_x_names(self):
x_names = []
for i in range(self.num_inputs):
x_names.append('x{}'.format(i))
return x_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.x = []
for i in range(self.num_inputs):
self.x.append(
np.random.random(size=self.input_dim).astype(self.dtype))
tmp = []
x_names = self.get_x_names()
for i in range(self.num_inputs):
tmp.append((x_names[i], self.x[i]))
self.inputs = {'X': tmp}
self.outputs = {'Y': np.stack(self.x, axis=self.axis)}
self.attrs = {'axis': self.axis}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase):
def initParameters(self):
self.num_inputs = 16
class TestStackOp2(TestStackOpBase):
def initParameters(self):
self.num_inputs = 20
class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
class TestStackOp5(TestStackOpBase):
def initParameters(self):
self.axis = 1
class TestStackOp6(TestStackOpBase):
def initParameters(self):
self.axis = 3
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,9 @@ from __future__ import print_function
import six
from paddle.fluid import core
import paddle
def delete_ops(block, ops):
try:
......@@ -39,3 +42,133 @@ def find_op_by_output_arg(block, arg_name):
if arg_name in op.output_arg_names:
return index
return -1
def get_indent_space(indent, space_num=4):
ret = ""
for i in range(0, indent * space_num):
ret += " "
return ret
def variable_to_code(var):
"""
Get readable codes of fluid variable.
Args:
var: A fluid operator.
Returns:
string: The formatted string.
"""
var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".\
format(i="{", e="}", name=var.name, type=var.type, shape=var.shape, dtype=var.dtype)
if type(var) == paddle.fluid.framework.Parameter:
if var.trainable:
var_str = "trainable parameter " + var_str
else:
var_str = "parameter " + var_str
else:
var_str = "var " + var_str
if var.persistable:
var_str = "persist " + var_str
return var_str
def op_to_code(op):
"""
Get readable codes of fluid operator.
Args:
op: A fluid operator.
Returns:
string: The foramtted string.
"""
outputs_str = "{"
for i in range(0, len(op.output_names)):
outputs_str += "{name}=".format(name=op.output_names[i])
o = op.output(op.output_names[i])
outputs_str += "{value}".format(value=o)
if i != len(op.output_names) - 1:
outputs_str += ", "
outputs_str += "}"
inputs_str = "{"
for i in range(0, len(op.input_names)):
inputs_str += "{name}=".format(name=op.input_names[i])
o = op.input(op.input_names[i])
inputs_str += "{value}".format(value=o)
if i != len(op.input_names) - 1:
inputs_str += ", "
inputs_str += "}"
attrs_str = ""
for i in range(0, len(op.attr_names)):
name = op.attr_names[i]
attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK:
a = "{name} = block[{value}]".format(
name=name, type=attr_type, value=op.block_attr_id(name))
attrs_str += a
continue
if attr_type == core.AttrType.BLOCKS:
a = "{name} = blocks{value}".format(
name=name, type=attr_type, value=op.blocks_attr_ids(name))
attrs_str += a
continue
a = "{name} = {value}".format(
name=name, type=attr_type, value=op.desc.attr(name))
attrs_str += a
if i != len(op.attr_names) - 1:
attrs_str += ", "
if outputs_str != "{}":
op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
format(outputs = outputs_str, op_type=op.type, inputs=inputs_str, attrs=attrs_str)
else:
op_str = "{op_type}(inputs={inputs}, {attrs})".\
format(op_type=op.type, inputs=inputs_str, attrs=attrs_str)
return op_str
def program_to_code(prog):
"""
Print readable codes of fluid program.
Args:
prog : A fluid program.
An example result like bellow:
https://github.com/PaddlePaddle/Paddle/pull/12673
"""
indent = 0
block_idx = 0
for block in prog.blocks:
print("{0}{1} // block {2}".format(
get_indent_space(indent), '{', block_idx))
indent += 1
# sort all vars
all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0])
for var in all_vars:
print("{}{}".format(
get_indent_space(indent), variable_to_code(var[1])))
if len(all_vars) > 0:
print("")
for op in block.ops:
print("{}{}".format(get_indent_space(indent), op_to_code(op)))
indent -= 1
print("{0}{1}".format(get_indent_space(indent), '}'))
block_idx += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册