提交 601c1a35 编写于 作者: W wangmeng28

Merge remote-tracking branch 'upstream/develop' into factorization_machine_layer

......@@ -127,6 +127,7 @@ include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(external/nccl)
include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration
......@@ -159,7 +160,7 @@ set(EXTERNAL_LIBS
if(WITH_GPU)
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)
endif(WITH_GPU)
......
......@@ -62,11 +62,11 @@ else()
FIND_PACKAGE(CUDA REQUIRED)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
message(FATAL_ERROR "Paddle need CUDA >= 7.0 to compile")
message(FATAL_ERROR "Paddle needs CUDA >= 7.0 to compile")
endif()
if(NOT CUDNN_FOUND)
message(FATAL_ERROR "Paddle need cudnn to compile")
message(FATAL_ERROR "Paddle needs cudnn to compile")
endif()
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}")
......
INCLUDE(ExternalProject)
SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies
set(NCCL_BUILD_COMMAND "")
set(NCCL_INSTALL_COMMAND "")
set(NCCL_INSTALL_DIR "")
else()
# otherwise, we build nccl and link it.
set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install")
SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif()
ExternalProject_Add(
extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND ""
)
if (WITH_DSO)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile})
else()
add_library(nccl INTERFACE)
endif()
else()
ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl.a)
endif()
add_dependencies(nccl extern_nccl)
LIST(APPEND external_project_dependencies nccl)
......@@ -87,11 +87,8 @@ class OpInfoMap {
}
}
template <typename Callback>
void IterAllInfo(Callback callback) {
for (auto& it : map_) {
callback(it.first, it.second);
}
const std::unordered_map<std::string, const OpInfo>& map() const {
return map_;
}
private:
......
......@@ -18,6 +18,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
void VarDescBind::SetType(VarDesc::VarType type) { desc_.set_type(type); }
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
......
......@@ -75,9 +75,9 @@ class VarDescBind {
int32_t GetLodLevel() const;
VarDesc::VarType GetType() const { return desc_.type(); }
VarDesc::VarType GetType() const;
void SetType(VarDesc::VarType type) { desc_.set_type(type); }
void SetType(VarDesc::VarType type);
bool Persistable() const { return desc_.persistable(); }
......
......@@ -339,11 +339,15 @@ private:
* clear all grad
*/
void clearGrads() {
if (output_.grad) {
output_.grad->zeroMem();
}
for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
if (outputOtherDevice_[i].grad) {
outputOtherDevice_[i].grad->zeroMem();
}
}
}
/**
* Set deviceId of the params used in this layer.
......
......@@ -115,7 +115,8 @@ set(DEPS_OPS
softmax_with_cross_entropy_op
sum_op
pool_op
pool_with_index_op)
pool_with_index_op
lstm_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
......@@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
// im2col
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[1]);
paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
......@@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
}
}
}
......@@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
// gemm
Tensor filter_grad_slice =
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
namespace paddle {
namespace operators {
void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of Conv2DTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of Conv2DTransposeOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0,
"No Padding allowed in conv transpose op.");
}
PADDLE_ENFORCE_EQ(in_dims.size(), 4,
"Conv2DTransposeOp input should be 4-D tensor.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
"Conv2DTransposeOp filter should be 4-D tensor.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal.");
auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
ctx->SetOutputDim("Output",
{in_dims[0], filter_dims[1], output_height, output_width});
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is CMHW, where C is the number of "
"output image channels, M is the number of input image channels, "
"H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in "
"convolution transpose Scenario.");
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides",
"strides of convolution transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"paddings of convolution transpose operator.")
.SetDefault({0, 0});
AddComment(R"DOC(
The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
)DOC");
}
void Conv2DTransposeOpGrad::InferShape(
framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
ops::Conv2DTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2dtranspose,
ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2dtranspose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
class Conv2DTransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class Conv2DTransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// The filter will be reshaped, so it should not be constant pointer
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int k_h = filter.dims()[2];
const int k_w = filter.dims()[3];
const int c = output->dims()[1]; // output channels
const int o_h = output->dims()[2];
const int o_w = output->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
// use col_shape in the im2col and col2im calculation
DDim col_shape = {c, k_h, k_w, h, w};
// use col_matrix_shape in the gemm calculation
DDim col_matrix_shape = {c * k_h * k_w, h * w};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix;
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {c, o_h, o_w};
DDim input_matrix_shape = {m, h * w};
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// convolution transpose: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
// batch with size (M, h * w)
Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// filter size: (M, c * k_h * k_w)
// output size: (c, o_h, o_w)
Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix, T(0.0));
col2im(context.device_context(), output_batch, col, strides[0],
strides[1], 0, 0, 0, 0);
}
}
};
template <typename Place, typename T>
class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer b/c we will do reshape,
// but we should avoid modifying its value.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int k_h = filter.dims()[2];
const int k_w = filter.dims()[3];
const int c = output_grad->dims()[1]; // output channels
const int o_h = output_grad->dims()[2];
const int o_w = output_grad->dims()[3];
// Only im2col functor required for bp to get to the right shape
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
// use col_shape in the im2col and col2im calculation
DDim col_shape = {c, k_h, k_w, h, w};
// use col_matrix_shape in the gemm calculation
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
DDim output_shape = {c, o_h, o_w};
DDim input_matrix_shape = {m, h * w};
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// convolution transpose grad on input:
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
if (input_grad) {
Tensor col_matrix;
col_matrix.ShareDataWith(col);
DDim col_matrix_shape = {c * k_h * k_w, h * w};
col_matrix.Resize(col_matrix_shape);
input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
// batch with size (c, o_h * o_w)
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
// filter of size (m, c * k_h * k_w)
// batch with size (m, h, w)
Tensor input_grad_batch =
input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: dx = filter * dy
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
math::matmul<Place, T>(context.device_context(), filter, false,
col_matrix, false, T(1.0), &input_grad_batch,
T(0.0));
}
}
// filter gradient required
if (filter_grad) {
Tensor col_matrix_f;
col_matrix_f.ShareDataWith(col);
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
col_matrix_f.Resize(col_matrix_shape_f);
filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
auto t = framework::EigenVector<T>::Flatten(filter_grad_);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; ++i) {
// batch with size (c, o_h, o_w)
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
// input batch
Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// im2col: (c * h * w, k_h * k_w)
im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: d_filter = x * y_grad^T
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
math::matmul<Place, T>(context.device_context(), in_batch, false,
col_matrix_f, true, T(1.0), &filter_grad_,
T(1.0));
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -59,7 +59,8 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)");
AddInput("Y",
"The input used as reference for cropping"
" with the same dimension as X. ");
" with the same dimension as X. ")
.AsDispensable();
AddOutput("Out",
"The output of crop op "
"with the same dimension as X.");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
class FCOp : public NetOp {
public:
FCOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE(!Inputs("X").empty(),
"Inputs(X) of FCOp should not be null.");
PADDLE_ENFORCE(!Inputs("W").empty(),
"Inputs(W) of FCOp should not be null.");
PADDLE_ENFORCE(!Outputs("MulOut").empty(),
"Outputs(MulOut) of FCOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of FCOp should not be null.");
auto x = Inputs("X");
auto w = Inputs("W");
auto mul_out = Outputs("MulOut");
PADDLE_ENFORCE_EQ(
x.size(), w.size(),
"The size of inputs X(%d) should be the same as that of weights W(%d).",
x.size(), w.size());
PADDLE_ENFORCE_EQ(mul_out.size(), x.size(),
"The size of intermediate mul_out(%d) should be the same "
"as that of inputs X(%d).",
mul_out.size(), x.size());
size_t n = x.size();
PADDLE_ENFORCE_GE(n, static_cast<size_t>(1),
"The size of inputs X(%d) should be no less than 1.", n);
auto x_num_col_dims = Attr<std::vector<int>>("xNumColDims");
// Set all values or set no values (use the default value)
if (!x_num_col_dims.empty()) {
PADDLE_ENFORCE_EQ(x_num_col_dims.size(), n,
"The size of attribute xNumColDims(%d) should be the "
"same as that of inputs X(%d).",
x_num_col_dims.size(), n);
} else {
x_num_col_dims.resize(n);
for (size_t i = 0; i < n; i++) {
x_num_col_dims[i] = 1;
}
}
// mul_out[i] = X[i] * W[i]
for (size_t i = 0; i < n; i++) {
framework::AttributeMap mul_attr;
mul_attr["x_num_col_dims"] = static_cast<int>(x_num_col_dims[i]);
mul_attr["y_num_col_dims"] = static_cast<int>(1);
AppendOp(
framework::OpRegistry::CreateOp("mul", {{"X", {x[i]}}, {"Y", {w[i]}}},
{{"Out", {mul_out[i]}}}, mul_attr));
}
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
auto sum_out = mul_out[0];
if (n > 1) {
PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
"Output(SumOut) of FCOp should not be null when the "
"size of Inputs(X) > 1.");
sum_out = Output("SumOut");
AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
{{"Out", {sum_out}}}, {}));
} else {
if (Output("SumOut") != framework::kEmptyVarName) {
this->Rename(Output("SumOut"), framework::kEmptyVarName);
}
}
// add_out = sum_out + b
auto b = Input("B");
auto add_out = sum_out;
if (b != framework::kEmptyVarName) {
PADDLE_ENFORCE_NE(
Output("AddOut"), framework::kEmptyVarName,
"Output(AddOut) of FCOp should not be null when Input(B) is set.");
add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {sum_out}}, {"Y", {Input("B")}}},
{{"Out", {add_out}}}, {}));
} else {
if (Output("AddOut") != framework::kEmptyVarName) {
this->Rename(Output("AddOut"), framework::kEmptyVarName);
}
}
auto activation = Attr<std::string>("activation");
AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}},
{{"Y", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FCOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(A vector of Tensors) each input Tensor can be of arbitrary "
"dimension, and will be reshaped to a 2-D matrix of size "
"(minibatch, number_of_input_features) according to attribute "
"xNumColDims.")
.AsDuplicable();
AddInput("W",
"(A vector of Tensors) the weights of FC operator, a "
"vector of 2-D matrix of size "
"(number_of_input_features, number_of_neurons).")
.AsDuplicable();
AddInput("B",
"(Tensor) the bias of FC operator, a 1-D vector of size "
"number_of_neurons.");
AddOutput("Out",
"(Tensor) the activated output matrix of FC operator, a 2-D "
"matrix of size (minibatch, number_of_neurons).");
AddOutput("MulOut",
"(A vector of Tensors) the intermediate outputs of FC operator, "
"each Tensor saving the product of X_i * W_i.")
.AsIntermediate()
.AsDuplicable();
AddOutput(
"SumOut",
"(Tensor) the intermediate output of FC operator, "
"saving the sum of the products of X and W, that is sum{X_i * W_i}.")
.AsIntermediate();
AddOutput("AddOut",
"(Tensor) the non-actived output of FC operator, "
"saving sum{X_i * W_i} + B.")
.AsIntermediate();
AddAttr<std::string>(
"activation",
"(string, default identity) the activation type of FC operator.")
.SetDefault("identity")
.InEnum({"identity", "sigmoid", "softmax"});
AddAttr<std::vector<int>>(
"xNumColDims",
"(std::vector<int>) The inputs Tensors of FC operator can be of "
"more than 2 dimensions. In that case, each input Tensor `X_i` will be "
"reshaped to a 2-D matrix. The matrix's first dimension "
"(the length of column) will be the product of `X_i`'s last "
"`xNumColDims_i` dimensions, that is "
"`X_i.dims[0] x ... x X_i.dims[xNumColDims_i - 1]`. "
"The matrix's second dimension (the length of row) will be the product "
"of `X_i`'s first `rank - xNumColDims_i` dimensions, that is "
"`X_i.dims[xNumColDims_i] x ... x X_i.dims[rank - 1]`)")
.SetDefault(std::vector<int>{});
AddComment(R"DOC(
Fully Connected Operator, known as Fully Connected Layer or Inner Product Layer
in Convolutional Neural Networks. Neurons in a fully connected layer have
full connections to all activations in the previous layer.
It computes an inner product of a set of
learned weights with a matrix multiplication followed by a bias offset
(optionally).
Equation:
Out = Act(sum_n{X_i * W_i} + B)
where X_i is Tensor that will be reshaped to a 2-D matrix of size (M x K),
usually M is the minibatch size and K is the number of input features.
W_i is a 2-D matrix of size (K x N), where N means the number of neurons
in the fully connected layer. B is a 1-D vector of size N.
Thus, the output Out is a 2-D matrix of size (M x N).
Activation type can be set to `identity` (default), `sigmoid` or `softmax`.
All the inputs can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with first input (`X[0]`).
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker);
......@@ -54,8 +54,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
......@@ -89,7 +88,8 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"weights of output candidate with shape [frame_size, frame_size]");
AddInput("Bias",
"(Tensor) Bias vector with shape [1, frame_size * 3] concating "
"bias of the update gate, reset gate and output candidate.");
"bias of the update gate, reset gate and output candidate.")
.AsDispensable();
AddOutput("Gate",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"output of update gate, reset gate and output candidate")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/operators/scale_op.h"
namespace paddle {
namespace operators {
// The identity operator is an alias of the scale operator. This is also an
// example for creating an alias for an existing operator.
template <typename AttrType>
class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IdentityOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of identity operator.");
AddOutput("Y", "The output tensor of identity operator.");
AddComment(R"DOC(
The identity operator is an alias of the scale operator
with the attribute scale fixed to 1.0.
)DOC");
}
};
template <typename AttrType>
class IdentityOp : public NetOp {
public:
IdentityOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
"Input(X) of IdentityOp should not be null.");
PADDLE_ENFORCE_NE(Output("Y"), framework::kEmptyVarName,
"Output(Y) of IdentityOp should not be null.");
AppendOp(framework::OpRegistry::CreateOp(
"scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}},
{{"scale", static_cast<AttrType>(1)}}));
CompleteAddOp(false);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp<float>,
ops::IdentityOpMaker<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
class InterpOp : public NetOp {
public:
InterpOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
"Input(X) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName,
"Input(Y) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName,
"Input(W) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName,
"Output(SubOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName,
"Output(MulOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of InterpOp should not be null.");
// SubOut = X - Y
auto x = Input("X");
auto y = Input("Y");
auto sub_out = Output("SubOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {}));
// MulOut = SubOut * W = (X - Y) * W
auto w = Input("W");
auto mul_out = Output("MulOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}},
{{"axis", 0}}));
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
AppendOp(framework::OpRegistry::CreateOp("elementwise_add",
{{"X", {mul_out}}, {"Y", {y}}},
{{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
class InterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
"containing data samples, the first input of interp_op");
AddInput("Y",
"(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
"containing data samples, the second input of interp_op");
AddInput("W",
"(Tensor), 1-D Vector of shape [batch_size],"
"the interpolated values in the half-open interval [0.0, 1.0)");
AddOutput("SubOut",
"(Tensor), the intermediate subtraction outputs, saving X - Y.")
.AsIntermediate();
AddOutput("MulOut",
"(Tensor), the intermediate multiplication outputs,"
"saving the elementwise multiplication of (X - Y) and W.")
.AsIntermediate();
AddOutput("Out",
"(Tensor), the output of interp_op, same shape with X,"
"returns the first-dimensional piecewise linear interpolant "
"between X and Y");
AddComment(R"DOC(
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
Equation:
Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i])
= (X.row[i] - Y.row[i]) * W[i] + Y.row[i]
Example:
X = [[1,2],[3,4]],
Y = [[2,1],[4,3]],
W = [0.3, 0.4]
Then, Out = [[1.7,1.3],[3.6,3.4]]
where 1.7 = 1*0.3+2*(1-0.3),
1.3 = 2*0.3+1*(1-0.3),
3.6 = 3*0.4+4*(1-0.4),
3.4 = 4*0.4+3*(1-0.4)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker);
......@@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
}
......@@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64"
"contains the ids to be looked up in W.");
"contains the ids to be looked up in W."
"Ids must be a column vector with rank = 2."
"The 2nd dimension size must be 1");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment(R"DOC(
This operator is used to perform lookups on the parameter W,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_op.h"
namespace paddle {
namespace operators {
class LSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
int frame_size = x_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("usePeepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
ctx->SetOutputDim("BatchGate", x_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
};
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.");
AddInput("C0",
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time");
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}");
AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"setting `usePeepholes` True. "
"1. `usePeepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddAttr<bool>("usePeepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("isReverse",
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<std::string>(
"gateActivation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid");
AddAttr<std::string>("cellActivation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh");
AddAttr<std::string>("candidateActivation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh");
AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator
The defalut implementation is diagonal/peephole connection [1], the formula is
as follows
i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)
f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)
\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)
c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t}
h_t = o_t ⊙ act_h(c_t)
where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix
of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$
are diagonal weight matrices for peephole connections. In our implenmention,
We use vectors to reprenset these diagonal weight matrices. The b terms
denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$
is the non-line actications, such as logistic sigmoid function, and
\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate,
output gate and cell activation vectors, all of which are the same size as
the cell output activation vector \f$h\f$.
The ⊙ is the element-wise product of the vectors, \f$act_g\f$ and \f$act_h\f$
are the cell input and cell output activation functions, `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Set `usePeepholes` False to disable peephole connection [2]. The formula
is omitted here.
@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
operations on the input x_{t} were NOT included in this operator.
Users can choose to use fully-connect operator before LSTM operator.
[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
INTERSPEECH, 2014.
[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory.
Neural Computation, 9(8):1735-1780, 1997.
)DOC");
}
};
class LSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::CPUPlace, float>,
ops::LSTMKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(lstm_grad,
ops::LSTMGradKernel<paddle::platform::CPUPlace, float>,
ops::LSTMGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::GPUPlace, float>,
ops::LSTMKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(lstm_grad,
ops::LSTMGradKernel<paddle::platform::GPUPlace, float>,
ops::LSTMGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* weight = ctx.Input<framework::Tensor>("Weight");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented.
// So copy LoD here.
ctx.ShareLoD("Input", "Hidden");
ctx.ShareLoD("Input", "Cell");
bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
framework::DDim dims({in_dims[0], frame_size});
if (bias) {
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto b = EigenMatrix<T>::From(*bias);
auto gate = EigenMatrix<T>::From(*batch_gate);
gate.device(ctx.GetEigenDevice<Place>()) =
gate +
b.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
.broadcast(
Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
}
math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
batch_out.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor out_t = batch_out.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
int cur_batch_size = bend - bstart;
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else if : FIXME support the initial hidden and cell
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_out.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(ctx.device_context(), batch_out, *hidden_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out);
}
};
template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
};
} // namespace operators
} // namespace paddle
......@@ -19,7 +19,6 @@
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
template <typename T>
......
add_subdirectory(detail)
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
......@@ -7,6 +9,8 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -14,6 +18,8 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
......@@ -22,8 +22,6 @@ namespace {
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
......
if(WITH_AVX)
cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc)
else()
cc_library(activation_functions SRCS hl_cpu_functions.cc)
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
#include "paddle/operators/math/lstm_compute.h"
/**
* Active functions: sigmoid, relu, tanh and linear.
*/
#define FLOAT_ACTIVE_FUNCTION \
{ \
hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
hppl::typef::linear \
}
#define DOUBLE_ACTIVE_FUNCTION \
{ \
hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
hppl::typed::linear \
}
#define AVX_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace hppl {
using activation_mode_t = paddle::operators::math::activation_mode_t;
/**
* Hppl supports sigmoid, relu, tanh, linear active functions
* for neural networks' forward and backward activation.
*/
template <class T>
class Active {
public:
typedef T (*forward)(T);
typedef T (*backward)(T, T);
};
template <typename T>
struct ForwardActType;
template <>
struct ForwardActType<float> {
using type = Active<float>::forward;
};
template <>
struct ForwardActType<double> {
using type = Active<double>::forward;
};
template <typename T>
struct BackwardActType;
template <>
struct BackwardActType<float> {
using type = Active<float>::backward;
};
template <>
struct BackwardActType<double> {
using type = Active<double>::backward;
};
#ifdef __NVCC__
namespace gpu {
static __device__ Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward_d[] =
DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
__device__ typename ForwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct ForwardAct<float> {
__device__ ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
__device__ ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
__device__ typename BackwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct BackwardAct<float> {
__device__ BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
__device__ BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace gpu
#else
namespace cpu {
static Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static Active<double>::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
typename ForwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct ForwardAct<float> {
ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
typename BackwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct BackwardAct<float> {
BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace cpu
#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION;
} // namespace avx
#endif
#endif
} // namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <immintrin.h>
#include "hl_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h"
namespace hppl {
__m256 exp(__m256 a) { return exp256_ps(a); }
__m256 relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp);
}
__m256 sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp;
}
__m256 tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f));
}
__m256 linear(const __m256 a) { return a; }
__m256 relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
_mm256_set1_ps(1.0f)));
}
__m256 sigmoid(const __m256 a, const __m256 b) {
return _mm256_mul_ps(_mm256_mul_ps(a, b),
_mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}
__m256 tanh(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}
__m256 linear(const __m256 a, const __m256 b) { return a; }
} // namespace hppl
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
} // namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "hl_functions.h"
namespace hppl {
namespace typef {
float relu(const float a) {
return a > static_cast<float>(0.0) ? a : static_cast<float>(0.0);
}
float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<float>(1.0) / (static_cast<float>(1.0) + exp(-tmp));
}
float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
float linear(const float a) { return a; }
float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); }
float sigmoid(const float a, const float b) {
return a * b * (static_cast<float>(1) - b);
}
float tanh(const float a, const float b) {
return a * (static_cast<float>(1) - b * b);
}
float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
double relu(const double a) {
return a > static_cast<double>(0.0) ? a : static_cast<double>(0.0);
}
double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<double>(1.0) / (static_cast<double>(1.0) + exp(-tmp));
}
double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
double linear(const double a) { return a; }
double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
double sigmoid(const double a, const double b) {
return a * b * (static_cast<double>(1) - b);
}
double tanh(const double a, const double b) {
return a * (static_cast<double>(1) - b * b);
}
double linear(const double a, const double b) { return a; }
} // namespace typed
} // namespace hppl
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
/**
* The maximum input value for exp, used to avoid overflow problem.
* currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
#ifndef __NVCC__
namespace hppl {
namespace typef {
float relu(const float a);
float sigmoid(const float a);
float tanh(const float a);
float linear(const float a);
float relu(const float a, const float b);
float sigmoid(const float a, const float b);
float tanh(const float a, const float b);
float linear(const float a, const float b);
} // namespace typef
namespace typed {
double relu(const double a);
double sigmoid(const double a);
double tanh(const double a);
double linear(const double a);
double relu(const double a, const double b);
double sigmoid(const double a, const double b);
double tanh(const double a, const double b);
double linear(const double a, const double b);
} // namespace typed
} // namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace hppl {
namespace typef {
__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; }
__device__ static float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return __fdividef(1.0f, 1.0f + __expf(-tmp));
}
__device__ static float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return __fdividef(2.0f, (1.0f + __expf(-2.0f * tmp))) - 1.0f;
}
__device__ static float linear(const float a) { return a; }
__device__ static float relu(const float a, const float b) {
return a * (b > 0.0f ? 1.0f : 0.0f);
}
__device__ static float sigmoid(const float a, const float b) {
return a * b * (1.0f - b);
}
__device__ static float tanh(const float a, const float b) {
return a * (1.0f - b * b);
}
__device__ static float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; }
__device__ static double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
}
__device__ static double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
}
__device__ static double linear(const double a) { return a; }
__device__ static double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
__device__ static double sigmoid(const double a, const double b) {
return a * b * (1 - b);
}
__device__ static double tanh(const double a, const double b) {
return a * (1.0 - b * b);
}
__device__ static double linear(const double a, const double b) { return a; }
} // namespace typef
} // namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#ifndef __NVCC__
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI;
T rCheckF;
T rCheckO;
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
hppl::cpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
value.stateValue[i] = rState;
value.stateActiveValue[i] = rStateAtv;
value.outputValue[i] = rOut;
}
}
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI;
T rCheckF;
T rCheckO;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
T *gradIn = grad.gateGrad;
T *gradIg = grad.gateGrad + frameSize;
T *gradFg = grad.gateGrad + frameSize * 2;
T *gradOg = grad.gateGrad + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i];
rStateGrad = grad.stateGrad[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
hppl::cpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, act(active_node), act(active_gate), act(active_state));
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
grad.stateGrad[i] = rStateGrad;
if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
}
}
template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv;
__m256 rOut;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node],
hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
((__m256 *)value.stateValue)[i] = rState;
((__m256 *)value.stateActiveValue)[i] = rStateAtv;
((__m256 *)value.outputValue)[i] = rOut;
}
#endif
}
template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rGradIn;
__m256 rGradIg;
__m256 rGradFg;
__m256 rGradOg;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rPrevStateGrad;
__m256 rStateGrad;
__m256 rState;
__m256 rStateAtv;
__m256 rOutputGrad;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rCheckIGrad;
__m256 rCheckFGrad;
__m256 rCheckOGrad;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
__m256 *gradIn = (__m256 *)grad.gateGrad;
__m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
__m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
__m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i];
rStateGrad = ((__m256 *)grad.stateGrad)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, hppl::avx::backward[active_node],
hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
((__m256 *)grad.stateGrad)[i] = rStateGrad;
if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
}
if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
}
#endif
}
template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
}
}
template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
}
}
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.outputValue += batchIdx * frameSize;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
}
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
value.gateValue[frameIdx + frameSize * 2] = rValueFg;
value.gateValue[frameIdx + frameSize * 3] = rValueOg;
value.stateValue[frameIdx] = rState;
value.stateActiveValue[frameIdx] = rStateAtv;
value.outputValue[frameIdx] = rOut;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
grad.gateGrad += batchIdx * frameSize * 4;
grad.stateGrad += batchIdx * frameSize;
grad.outputGrad += batchIdx * frameSize;
}
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
rState = value.stateValue[frameIdx];
rStateAtv = value.stateActiveValue[frameIdx];
rOutputGrad = grad.outputGrad[frameIdx];
rStateGrad = grad.stateGrad[frameIdx];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
act(active_node), act(active_gate), act(active_state));
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
grad.stateGrad[frameIdx] = rStateGrad;
if (grad.prevStateGrad) {
if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
grad.prevStateGrad[frameIdx] = rPrevStateGrad;
}
if (isBatch) {
if (value.prevStateValue) {
if (grad.checkIgGrad)
paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
rCheckIGrad);
if (grad.checkFgGrad)
paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
rCheckFGrad);
}
if (grad.checkOgGrad)
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
} else {
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
}
}
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
if (batchSize == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
namespace forward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
hppl::Active<__m256>::forward actInput,
hppl::Active<__m256>::forward actGate,
hppl::Active<__m256>::forward actState) {
valueIn = actInput(valueIn);
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg));
valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
stateAtv = actState(state);
output = _mm256_mul_ps(valueOg, stateAtv);
}
#endif
#endif
};
} // namespace forward
namespace backward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
gradIg = actGate(stateGrad * valueIn, valueIg);
gradFg = actGate(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad,
hppl::Active<__m256>::backward actInput,
hppl::Active<__m256>::backward actGate,
hppl::Active<__m256>::backward actState) {
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
stateGrad = _mm256_add_ps(
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF));
prevStateGrad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
checkIGrad = _mm256_mul_ps(gradIg, prevState);
checkFGrad = _mm256_mul_ps(gradFg, prevState);
checkOGrad = _mm256_mul_ps(gradOg, state);
}
#endif
#endif
};
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -41,6 +41,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
......@@ -52,16 +68,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) < 0 ||
(im_row_idx - padding_height) >= input_height ||
(im_col_idx - padding_width) < 0 ||
(im_col_idx - padding_width) >= input_width) {
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
im_col_idx >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
im_row_idx += c_im * input_height;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
......@@ -82,7 +96,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -92,6 +107,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
T* im_data = im.data<T>();
......@@ -103,14 +134,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) >= 0 &&
(im_row_idx - padding_height) < input_height &&
(im_col_idx - padding_width) >= 0 &&
(im_col_idx - padding_width) < input_width) {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) {
im_row_idx += c_im * input_height;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
}
......@@ -140,8 +169,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -152,6 +181,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
......@@ -163,10 +207,10 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
......@@ -201,7 +245,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -212,6 +257,21 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
......@@ -223,9 +283,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) *
......
......@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int num_outputs = input_channels * output_height * output_width;
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
......@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
filter_width, stride_height, stride_width, padding_up, padding_left,
output_height, output_width, col.data<T>());
}
};
......@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
(input_width + 2 * padding_width);
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
size_t num_kernels = input_channels *
(input_height + padding_up + padding_down) *
(input_width + padding_left + padding_right);
size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512;
......@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, im.data<T>());
num_kernels, col.data<T>(), input_height + padding_up + padding_down,
input_width + padding_left + padding_left, input_channels,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width, im.data<T>());
}
};
......@@ -238,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -250,6 +270,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
......@@ -274,8 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
}
};
......@@ -322,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -333,6 +363,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
......@@ -357,8 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
}
};
......
......@@ -74,8 +74,8 @@ class Im2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width);
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right);
};
template <ColFormat Format, typename Place, typename T>
......@@ -83,7 +83,8 @@ class Col2ImFunctor {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width);
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
};
} // namespace math
......
......@@ -35,6 +35,12 @@ void testIm2col() {
*
* output_ocf = [0, 1, 3, 4
* 1, 2, 4, 5]
*
* col2im_cfo = [0, 2, 2
* 3, 4, 5]
*
* col2im_ocf = [0, 2, 2
* 3, 4, 5]
*/
int input_height = 2;
int input_width = 3;
......@@ -59,7 +65,7 @@ void testIm2col() {
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_WITH_CUDA
}
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
......@@ -71,6 +77,7 @@ void testIm2col() {
output_ocf.mutable_data<float>(
{output_height, output_width, 1, filter_size, filter_size}, *place);
// Im2Col
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float>
im2col;
......@@ -78,8 +85,13 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -88,14 +100,9 @@ void testIm2col() {
output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context);
out_cfo_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_cfo_ptr[0], 0);
EXPECT_EQ(out_cfo_ptr[1], 1);
EXPECT_EQ(out_cfo_ptr[2], 1);
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]);
}
float* out_ocf_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -104,14 +111,60 @@ void testIm2col() {
output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context);
out_ocf_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_ocf_ptr[0], 0);
EXPECT_EQ(out_ocf_ptr[1], 1);
EXPECT_EQ(out_ocf_ptr[2], 3);
EXPECT_EQ(out_ocf_ptr[3], 4);
EXPECT_EQ(out_ocf_ptr[4], 1);
EXPECT_EQ(out_ocf_ptr[5], 2);
EXPECT_EQ(out_ocf_ptr[6], 4);
EXPECT_EQ(out_ocf_ptr[7], 5);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
}
// Col2Im: kCFO
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float>
col2im;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
float col2im_data[] = {0, 2, 2, 3, 8, 5};
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
}
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
// Col2Im: kOCF
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
}
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
}
TEST(math, im2col) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
}
}
}
};
template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
}
grad.gateGrad += frame_size * 4;
grad.stateGrad += frame_size;
grad.stateActiveGrad += frame_size;
grad.outputGrad += frame_size;
if (grad.prevStateGrad) {
grad.prevStateGrad += frame_size;
}
}
}
};
template class LstmUnitFunctor<platform::CPUPlace, float>;
template class LstmUnitFunctor<platform::CPUPlace, double>;
template class LstmUnitGradFunctor<platform::CPUPlace, float>;
template class LstmUnitGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
#include "paddle/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
}
};
template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
}
};
template class LstmUnitFunctor<platform::GPUPlace, float>;
template class LstmUnitFunctor<platform::GPUPlace, double>;
template class LstmUnitGradFunctor<platform::GPUPlace, float>;
template class LstmUnitGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
typedef enum {
HL_ACTIVATION_SIGMOID = 0,
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
} activation_mode_t;
template <class T>
struct LstmMetaValue {
T *gateValue;
T *prevStateValue;
T *stateValue;
T *stateActiveValue;
T *outputValue;
T *checkIg;
T *checkFg;
T *checkOg;
};
template <class T>
struct LstmMetaGrad {
T *gateGrad;
T *prevStateGrad;
T *stateGrad;
T *stateActiveGrad;
T *outputGrad;
T *checkIgGrad;
T *checkFgGrad;
T *checkOgGrad;
};
inline activation_mode_t ActiveType(const std::string &type) {
if (type == "sigmoid") {
return HL_ACTIVATION_SIGMOID;
} else if (type == "relu") {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "identity" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
}
}
template <typename Place, typename T>
class LstmUnitFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};
template <typename Place, typename T>
class LstmUnitGradFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
"The src must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL,
"The dst must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
"The width of src and dst must be same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
for (int i = 0; i < height; ++i) {
if (is_src_index) {
memcpy(dst_data + i * width, src_data + index[i] * width,
width * sizeof(T));
} else {
memcpy(dst_data + index[i] * width, src_data + i * width,
width * sizeof(T));
}
}
}
};
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
int64_t height, int64_t width,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * GridDimX;
while (id < height) {
int src_idx = is_src_index ? index[id] : id;
int dst_idx = is_src_index ? id : index[id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += BlockDimX) {
dst_data[i] = src_data[i];
}
id += BlockDimY * GridDimX;
}
}
template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2,
"The src must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
"The dst must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
"The width of src and dst must be same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
dim3 threads(128, 8);
dim3 grid(8, 1);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
src_data, dst_data, index, height, width, is_src_index);
}
};
template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true,
// copy the indexed rows of input src to the output dst.
// If is_src_index is false,
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index);
};
template <typename Place, typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
SeqInfo(int start, int length, int seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
};
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(),
[](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// calculate the start position of each batch
// (numBatch equal the maxLength of sequences)
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// The batch number represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
paddle::framework::LoD batch_lods;
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int num_batch = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor
auto dims = lod_tensor.dims();
batch_lods[1].resize(static_cast<size_t>(dims[0]));
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < num_batch; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length;
int start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
batch_id++;
} else {
break;
}
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
}
batch.set_lod(batch_lods);
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, seq2batch_idx, batch, true);
}
};
template <typename Place, typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& batch,
framework::LoDTensor& lod_tensor) const {
auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
"The LoD size of input `batch` should be 2.");
auto out_lod = lod_tensor.lod()[0];
auto num = out_lod[out_lod.size() - 1];
PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
PADDLE_ENFORCE_EQ(num, in_lod[1].size());
PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
CopyMatrixRowsFunctor<Place, T> to_seq;
size_t* index = in_lod[1].data();
to_seq(context, batch, index, lod_tensor, false);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -49,7 +49,19 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
ctx->SetOutputDim("Out", {x_mat_dims[0], y_mat_dims[1]});
std::vector<int64_t> output_dims;
output_dims.reserve(
static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
for (int i = 0; i < x_num_col_dims; ++i) {
output_dims.push_back(x_dims[i]);
}
for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......@@ -109,15 +121,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto y_mat_dims = framework::flatten_to_2d(
y_dims, ctx->Attrs().Get<int>("y_num_col_dims"));
PADDLE_ENFORCE_EQ(
x_mat_dims[0], out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand.");
PADDLE_ENFORCE_EQ(
y_mat_dims[1], out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
......
......@@ -46,8 +46,15 @@ class MulKernel : public framework::OpKernel<T> {
: *y;
z->mutable_data<T>(context.GetPlace());
auto z_dim = z->dims();
if (z_dim.size() != 2) {
z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<Place, T>(context.device_context(), x_matrix, false, y_matrix,
false, 1, z, 0);
if (z_dim.size() != 2) {
z->Resize(z_dim);
}
}
};
......@@ -67,6 +74,11 @@ class MulGradKernel : public framework::OpKernel<T> {
: *y;
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx) {
......@@ -74,9 +86,10 @@ class MulGradKernel : public framework::OpKernel<T> {
Tensor dx_matrix = dx->dims().size() > 2
? framework::ReshapeToMatrix(*dx, x_num_col_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<Place, T>(ctx.device_context(), *dout, false, y_matrix, true,
1, &dx_matrix, 0);
math::matmul<Place, T>(ctx.device_context(), dout_mat, false, y_matrix,
true, 1, &dx_matrix, 0);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
......@@ -84,8 +97,8 @@ class MulGradKernel : public framework::OpKernel<T> {
? framework::ReshapeToMatrix(*dy, y_num_col_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<Place, T>(ctx.device_context(), x_matrix, true, *dout, false,
1, &dy_matrix, 0);
math::matmul<Place, T>(ctx.device_context(), x_matrix, true, dout_mat,
false, 1, &dy_matrix, 0);
}
}
};
......
......@@ -160,66 +160,6 @@ class ReduceMinOpMaker : public ReduceOpMaker {
}
};
class NormOp : public NetOp {
public:
NormOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
"Input(X) of NormOp should not be null.");
PADDLE_ENFORCE_NE(Output("AbsOut"), framework::kEmptyVarName,
"Output(AbsOut) of NormOp should not be null.");
PADDLE_ENFORCE_NE(Output("PowOut"), framework::kEmptyVarName,
"Output(PowOut) of NormOp should not be null.");
PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
"Output(SumOut) of NormOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of NormOp should not be null.");
auto dim = Attr<int>("dim");
auto keep_dim = Attr<bool>("keep_dim");
auto p = Attr<float>("p");
PADDLE_ENFORCE_GT(p, 0, "Order of the norm should be positive.");
AppendOp(framework::OpRegistry::CreateOp("abs", {{"X", {Input("X")}}},
{{"Y", {Output("AbsOut")}}}, {}));
AppendOp(framework::OpRegistry::CreateOp("pow", {{"X", {Output("AbsOut")}}},
{{"Y", {Output("PowOut")}}},
{{"factor", p}}));
framework::AttributeMap sum_attr;
sum_attr["dim"] = dim;
sum_attr["keep_dim"] = keep_dim;
AppendOp(framework::OpRegistry::CreateOp(
"reduce_sum", {{"X", {Output("PowOut")}}},
{{"Out", {Output("SumOut")}}}, sum_attr));
AppendOp(framework::OpRegistry::CreateOp(
"pow", {{"X", {Output("SumOut")}}}, {{"Y", {Output("Out")}}},
{{"factor", static_cast<float>(1. / p)}}));
CompleteAddOp(false);
}
};
class NormOpMaker : public ReduceOpMaker {
public:
NormOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
AddOutput("AbsOut",
"(Tensor) The intermediate output of Norm operator, "
"saving the absolute value of the input tensor X.")
.AsIntermediate();
AddOutput("PowOut",
"(Tensor) The intermediate output of Norm operator, "
"saving the p-th power of the output tensor AbsOut.")
.AsIntermediate();
AddOutput("SumOut",
"(Tensor) the intermediate output of Norm operator, "
"saving the sum of PowOut reduced on the given dimension.")
.AsIntermediate();
AddAttr<float>("p", "(float, default 2) The order of Norm.").SetDefault(2);
SetComment("Norm", "vector p-norm");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
......@@ -237,8 +177,6 @@ REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
ops::ReduceGradOp);
REGISTER_OP_WITHOUT_GRADIENT(norm, ops::NormOp, ops::NormOpMaker);
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
reduce_type, \
......
......@@ -62,11 +62,13 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("InsideWeight",
"Optional input tensor of smooth l1 loss op with the same shape "
"as X. If provided, the result of (X - Y) will be multiplied "
"by this tensor element by element.");
"by this tensor element by element.")
.AsDispensable();
AddInput("OutsideWeight",
"Optinal input of smooth l1 loss op with the same shape as X."
"If provided, the output smooth l1 loss will be multiplied by "
"this tensor element by element.");
"this tensor element by element.")
.AsDispensable();
AddOutput("Diff", "Intermediate variable to cache InsideWeight*(X-Y).")
.AsIntermediate();
AddOutput("Out", "Smooth l1 loss.");
......
......@@ -25,3 +25,4 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc DEPS dynamic_loader)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
DEPS dynamic_loader nccl)
......@@ -35,6 +35,11 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
DEFINE_string(nccl_dir, "",
"Specify path for loading nccl library, such as libcublas, "
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH");
namespace paddle {
namespace platform {
namespace dynload {
......@@ -157,6 +162,14 @@ void GetLapackDsoHandle(void** dso_handle) {
#endif
}
void GetNCCLDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so", dso_handle);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -58,6 +58,14 @@ void GetWarpCTCDsoHandle(void** dso_handle);
*/
void GetLapackDsoHandle(void** dso_handle);
/**
* @brief load the DSO of NVIDIA nccl
*
* @param **dso_handle dso handler
*
*/
void GetNCCLDsoHandle(void** dso_handle);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/dynload/nccl.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag nccl_dso_flag;
void *nccl_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <dlfcn.h>
#include <nccl.h>
#include <mutex>
#include "paddle/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag nccl_dso_flag;
extern void* nccl_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using nccl_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(nccl_dso_flag, \
paddle::platform::dynload::GetNCCLDsoHandle, \
&nccl_dso_handle); \
void* p_##__name = dlsym(nccl_dso_handle, #__name); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
ncclResult_t operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define NCCL_RAND_ROUTINE_EACH(__macro) \
__macro(ncclCommInitAll); \
__macro(ncclGetUniqueId); \
__macro(ncclCommInitRank); \
__macro(ncclCommDestroy); \
__macro(ncclCommCount); \
__macro(ncclCommCuDevice); \
__macro(ncclCommUserRank); \
__macro(ncclAllReduce); \
__macro(ncclBcast); \
__macro(ncclAllGather); \
__macro(ncclReduce); \
__macro(ncclGetErrorString);
NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -29,11 +29,14 @@ limitations under the License. */
#include <cxxabi.h> // for __cxa_demangle
#endif
#include <glog/logging.h>
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/dynload/nccl.h"
#include <cublas_v2.h>
#include <cudnn.h>
......@@ -172,6 +175,17 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
throw std::runtime_error(err + string::Sprintf(args...));
}
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
ncclResult_t stat, const Args&... args) {
if (stat == ncclSuccess) {
return;
} else {
throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
string::Sprintf(args...));
}
}
#endif // PADDLE_ONLY_CPU
template <typename T>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
static int dev_count = 0;
namespace paddle {
namespace platform {
TEST(NCCL, init) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
template <typename T>
struct PerThreadData {
thrust::device_vector<T> send_buff;
thrust::device_vector<T> recv_buff;
CUDADeviceContext dev_ctx;
T* SendBuff() { return thrust::raw_pointer_cast(send_buff.data()); }
T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); }
PerThreadData(int gpu_id, size_t size) : dev_ctx(GPUPlace(gpu_id)) {
send_buff.resize(size);
for (size_t i = 0; i < size; ++i) {
send_buff[i] = static_cast<T>(i);
}
recv_buff.resize(size);
}
};
static constexpr int ELEM_COUNT = 10000;
TEST(NCCL, all_reduce) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
VLOG(1) << "Initializing ncclComm";
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
VLOG(1) << "ncclComm initialized";
VLOG(1) << "Creating thread data";
std::vector<std::unique_ptr<PerThreadData<double>>> data;
data.reserve(dev_count);
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Creating thread data for device " << i;
SetDeviceId(i);
data.emplace_back(new PerThreadData<double>(i, ELEM_COUNT));
}
VLOG(1) << "Thread data created";
VLOG(1) << "Check send_buf data";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Check on device " << i;
SetDeviceId(i);
thrust::host_vector<double> tmp = data[i]->send_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
ASSERT_NEAR(static_cast<double>(j), tmp[j], 1e-5);
}
}
VLOG(1) << "Invoking ncclAllReduce";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Invoking ncclAllReduce with device " << i;
SetDeviceId(i);
PADDLE_ENFORCE(dynload::ncclAllReduce(
data[i]->SendBuff(), data[i]->RecvBuff(), ELEM_COUNT, ncclDouble,
ncclSum, comms[i], data[i]->dev_ctx.stream()));
VLOG(1) << "Invoked ncclAllReduce for device " << i;
}
VLOG(1) << "Invoked ncclAllReduce";
VLOG(1) << "Sync devices";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Sync device " << i;
SetDeviceId(i);
data[i]->dev_ctx.Wait();
}
VLOG(1) << "device synced";
for (int i = 0; i < dev_count; ++i) {
SetDeviceId(i);
VLOG(1) << "Checking vector on device " << i;
thrust::host_vector<double> tmp = data[i]->recv_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
auto elem = static_cast<double>(j);
elem *= dev_count;
ASSERT_NEAR(tmp[j], elem, 1e-4);
}
}
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
} // namespace platform
} // namespace paddle
int main(int argc, char** argv) {
dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -35,6 +35,7 @@ struct GPUPlace {
GPUPlace() : GPUPlace(0) {}
explicit GPUPlace(int d) : device(d) {}
inline int GetDeviceId() const { return device; }
// needed for variant equality comparison
inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
......
......@@ -4,3 +4,5 @@ if(WITH_PYTHON)
DEPS pybind python backward proto_desc tensor_array paddle_memory executor
${GLOB_OP_LIB})
endif(WITH_PYTHON)
cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB} tensor_array)
#include <iostream>
#include <sstream> // std::stringstream
#include <string>
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/pybind/pybind.h"
std::string Escape(const std::string& s) {
std::string r;
for (size_t i = 0; i < s.size(); i++) {
switch (s[i]) {
case '\"':
r += "\\\"";
break;
case '\\':
r += "\\\\";
break;
case '\n':
r += "\\n";
break;
case '\t':
r += "\\t";
case '\r':
break;
default:
r += s[i];
break;
}
}
return r;
}
std::string AttrType(paddle::framework::AttrType at) {
switch (at) {
case paddle::framework::INT:
return "int";
case paddle::framework::FLOAT:
return "float";
case paddle::framework::STRING:
return "string";
case paddle::framework::BOOLEAN:
return "bool";
case paddle::framework::INTS:
return "int array";
case paddle::framework::FLOATS:
return "float array";
case paddle::framework::STRINGS:
return "string array";
case paddle::framework::BOOLEANS:
return "bool array";
case paddle::framework::BLOCK:
return "block id";
}
return "UNKNOWN"; // not possible
}
void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) {
ss << " { "
<< "\n"
<< " \"name\" : \"" << Escape(v.name()) << "\",\n"
<< " \"comment\" : \"" << Escape(v.comment()) << "\",\n"
<< " \"duplicable\" : " << v.duplicable() << ",\n"
<< " \"intermediate\" : " << v.intermediate() << "\n"
<< " },";
}
void PrintAttr(const paddle::framework::OpProto::Attr& a,
std::stringstream& ss) {
ss << " { "
<< "\n"
<< " \"name\" : \"" << Escape(a.name()) << "\",\n"
<< " \"type\" : \"" << AttrType(a.type()) << "\",\n"
<< " \"comment\" : \"" << Escape(a.comment()) << "\",\n"
<< " \"generated\" : " << a.generated() << "\n"
<< " },";
}
void PrintOpProto(const std::string& type,
const paddle::framework::OpInfo& opinfo,
std::stringstream& ss) {
std::cerr << "Processing " << type << "\n";
const paddle::framework::OpProto* p = opinfo.proto_;
if (p == nullptr) {
return; // It is possible that an operator doesn't have OpProto.
}
ss << "{\n"
<< " \"type\" : \"" << Escape(p->type()) << "\",\n"
<< " \"comment\" : \"" << Escape(p->comment()) << "\",\n";
ss << " \"inputs\" : [ "
<< "\n";
for (int i = 0; i < p->inputs_size(); i++) {
PrintVar(p->inputs(i), ss);
}
ss.seekp(-1, ss.cur); // remove the trailing comma
ss << " ], "
<< "\n";
ss << " \"outputs\" : [ "
<< "\n";
for (int i = 0; i < p->outputs_size(); i++) {
PrintVar(p->outputs(i), ss);
}
ss.seekp(-1, ss.cur); // remove the trailing comma
ss << " ], "
<< "\n";
ss << " \"attrs\" : [ "
<< "\n";
for (int i = 0; i < p->attrs_size(); i++) {
PrintAttr(p->attrs(i), ss);
}
ss.seekp(-1, ss.cur); // remove the trailing comma
ss << " ] "
<< "\n";
ss << "},";
}
int main() {
std::stringstream ss;
ss << "[\n";
for (auto& iter : paddle::framework::OpInfoMap::Instance().map()) {
PrintOpProto(iter.first, iter.second, ss);
}
ss.seekp(-1, ss.cur); // remove the trailing comma
ss << "]\n";
std::cout << ss.str();
}
......@@ -257,6 +257,7 @@ void BindOpDesc(py::module &m) {
.def("block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape)
.def("infer_var_type", &OpDescBind::InferVarType)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
......
......@@ -225,15 +225,16 @@ All parameter, weight, gradient are variables in Paddle.
//! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
std::vector<py::bytes> ret_values;
OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type,
const OpInfo &info) {
if (!info.HasOpProtoAndChecker()) return;
for (auto &iter : OpInfoMap::Instance().map()) {
auto &info = iter.second;
if (info.HasOpProtoAndChecker()) {
std::string str;
PADDLE_ENFORCE(info.Proto().SerializeToString(&str),
PADDLE_ENFORCE(
info.Proto().SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.emplace_back(str);
});
}
}
return ret_values;
});
m.def_submodule(
......
......@@ -1457,11 +1457,13 @@ def dot_product_attention(encoded_sequence,
expanded = expand_layer(
input=transformed_state,
expanded_as=encoded_sequence,
expand_as=encoded_sequence,
name='%s_expand' % name)
m = linear_comb_layer(
weights=expanded, vectors=encoded_sequence, name='%s_dot-product')
weights=expanded,
vectors=encoded_sequence,
name='%s_dot-product' % name)
attention_weight = fc_layer(
input=m,
......
......@@ -53,8 +53,8 @@ class Variable(object):
if is_new_var:
self.desc.set_data_type(dtype)
else:
old_dtype = self.data_type()
if dtype != old_shape:
old_dtype = self.data_type
if dtype != old_dtype:
raise ValueError("Variable {0} has been created before. "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
......@@ -113,6 +113,10 @@ class Variable(object):
def lod_level(self):
return self.desc.lod_level()
@property
def type(self):
return self.desc.type()
@staticmethod
def _unique_var_name_():
uid = core.unique_integer() # unique during whole process.
......@@ -192,31 +196,32 @@ class Operator(object):
self.desc.set_type(type)
proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None:
given = set()
need = set()
for n in inputs:
given.add(n)
for m in proto.inputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
def find_name(var_list, name):
for var_name in var_list:
if var_name == name:
return True
return False
if inputs is not None:
for in_proto in proto.inputs:
found = find_name(inputs, in_proto.name)
assert found or in_proto.dispensable, "Input {} not found".format(
in_proto.name)
if found:
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
if not in_proto.duplicable and len(in_argus) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_argus)))
"Input %s expects only one input, but %d are given."
% (in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name)
self.desc.set_input(in_proto.name, in_argu_names)
else:
self.desc.set_input(in_proto.name, [])
if outputs is not None:
given = set()
......@@ -250,13 +255,14 @@ class Operator(object):
attr_name = attr.name
if (not attr_name in attrs) or (attrs[attr_name] is None):
continue
if not isinstance(attrs[attr_name], Block):
self.desc.set_attr(attr_name, attrs[attr_name])
else:
if isinstance(attrs[attr_name], Block):
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
else:
self.desc.set_attr(attr_name, attrs[attr_name])
self.desc.check_attrs()
if type not in {'feed', 'fetch'}:
self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc)
def __str__(self):
......
from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program
import paddle.v2.framework.core as core
import copy
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Variable, g_program, \
g_init_program
def unique_name(prefix):
uid = core.unique_integer() # unique during whole process.
......@@ -120,10 +123,7 @@ class LayerHelper(object):
if attr['name'] is None:
attr['name'] = unique_name(".".join([self.name, suffix]))
self.init_program.global_block().create_parameter(
name=attr['name'],
dtype=dtype,
shape=shape,
init_attr=attr['init_attr'])
dtype=dtype, shape=shape, **attr)
return self.program.global_block().create_parameter(
name=attr['name'], dtype=dtype, shape=shape)
......@@ -133,6 +133,9 @@ class LayerHelper(object):
dtype=dtype,
persistable=False)
def create_variable(self, *args, **kwargs):
return self.program.current_block().create_var(*args, **kwargs)
def create_global_variable(self, *args, **kwargs):
return self.program.global_block().create_var(
*args, persistable=False, **kwargs)
......
from paddle.v2.framework.layer_helper import LayerHelper
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
import re
__all__ = ['fc', 'data', 'cross_entropy', 'conv2d', 'pool2d']
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN'
]
def fc(input,
......@@ -24,7 +27,9 @@ def fc(input,
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
param_shape = list(input_shape[num_flatten_dims:]) + [size]
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype)
......@@ -36,10 +41,8 @@ def fc(input,
"Y": w,
},
outputs={"Out": tmp},
attrs={
'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': len(input_shape) - num_flatten_dims
})
attrs={'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': 1})
mul_results.append(tmp)
# sum
......@@ -55,6 +58,24 @@ def fc(input,
return helper.append_activation(pre_activation)
def embedding(input,
size,
data_type='float32',
param_attr=None,
program=None,
init_program=None):
helper = LayerHelper('embedding', **locals())
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=data_type)
tmp = helper.create_tmp_variable(data_type)
helper.append_op(
type='lookup_table',
inputs={'Ids': input,
'W': w},
outputs={'Out': tmp})
return tmp
def data(name,
shape,
data_type='float32',
......@@ -122,6 +143,19 @@ _create_op_func_('mean')
_create_op_func_('mul')
def concat(input, axis, program=None, init_program=None):
helper = LayerHelper('concat', **locals())
if not isinstance(input, list) and not isinstance(input, tuple):
input = [input]
out = helper.create_tmp_variable(dtype=input[0].data_type)
helper.append_op(
type='concat',
inputs={'X': input},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
def cross_entropy(input, label, **kwargs):
helper = LayerHelper('cross_entropy', **kwargs)
out = helper.create_tmp_variable(dtype=input.data_type)
......@@ -240,3 +274,170 @@ def pool2d(input,
})
return pool_out
class BlockGuard(object):
"""
BlockGuard used to create sub-block in program by using Python `with`
keyword.
"""
def __init__(self, program):
if not isinstance(program, Program):
raise TypeError("BlockGuard takes a program")
self.program = program
def __enter__(self):
self.program.create_block()
def __exit__(self, exc_type, exc_val, exc_tb):
self.program.rollback()
if exc_type is not None:
return False # re-raise exception
return True
class StaticRNNGuard(BlockGuard):
def __init__(self, rnn):
if not isinstance(rnn, StaticRNN):
raise TypeError("StaticRNNGuard takes an StaticRNN")
super(StaticRNNGuard, self).__init__(rnn.helper.program)
self.rnn = rnn
def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
class StaticRNNMemoryLink(object):
"""
:param init: the initial variable for Memory
:type init: Variable
:param pre_mem: the memory variable in previous time step
:type pre_mem: Variable
:param mem: the memory variable in current time step
:type mem: Variable
"""
def __init__(self, init, pre_mem, mem=None):
self.init = init
self.pre_mem = pre_mem
self.mem = mem
class StaticRNN(object):
BEFORE_RNN_BLOCK = 0
IN_RNN_BLOCK = 1
AFTER_RNN_BLOCK = 2
def __init__(self, name=None, program=None):
self.helper = LayerHelper("static_rnn", name=name, program=program)
self.memories = {} # memory map, from pre_mem.name --> MemoryLink
self.inputs = [] # input variable list in current block
self.outputs = [] # output variable list in parent block
self.status = StaticRNN.BEFORE_RNN_BLOCK # status flag.
# sequence length, since it is a static RNN, sequence length are fixed.
self.seq_len = None
def step(self):
return StaticRNNGuard(self)
def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK:
raise ValueError("You must invoke {0} in rnn block".format(method))
def memory(self, init=None, shape=None, dtype=None, init_value=0):
self._assert_in_rnn_block_('memory')
if init is None:
if shape is None or dtype is None:
raise ValueError(
"if init is None, memory at least need shape and dtype")
parent_block = self.parent_block()
var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var(
name=var_name, shape=shape, dtype=dtype, persistable=False)
parent_block.append_op(
type="fill_constant",
inputs={},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'data_type': boot_var.data_type
})
return self.memory(init=boot_var)
else:
pre_mem = self.helper.create_variable(
name=unique_name("@".join([self.helper.name, "mem"])),
dtype=init.data_type,
shape=init.shape)
self.memories[pre_mem.name] = StaticRNNMemoryLink(
init=init, pre_mem=pre_mem)
return pre_mem
def step_input(self, x):
self._assert_in_rnn_block_('step_input')
if not isinstance(x, Variable):
raise TypeError("step input takes a Variable")
if self.seq_len is None:
self.seq_len = x.shape[1]
elif self.seq_len != x.shape[1]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(
name=x.name,
dtype=x.data_type,
shape=[-1] + list(x.shape[2:]),
type=x.type)
self.inputs.append(ipt)
return ipt
def step_output(self, o):
self._assert_in_rnn_block_('step_output')
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
out_var = self.parent_block().create_var(
name=o.name,
shape=[-1, self.seq_len] + list(o.shape[1:]),
dtype=o.data_type)
self.outputs.append(out_var)
def output(self, *outputs):
for each in outputs:
self.step_output(each)
def update_memory(self, mem, var):
if not isinstance(mem, Variable) or not isinstance(var, Variable):
raise TypeError("update memory should take variables")
self.memories[mem.name].mem = var
def parent_block(self):
prog = self.helper.program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def complete_rnn_op(self):
# TODO(yuyang18): Create RNN Op here.
# Implement this method after RNN op complete.
pass
......@@ -4,6 +4,8 @@ import random
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
def grad_var_name(var_name):
......@@ -177,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op
def get_gradient(scope, op, inputs, outputs, grad_name, place,
def get_gradient(scope,
op,
inputs,
outputs,
grad_names,
place,
no_grad_set=None):
ctx = core.DeviceContext.create(place)
......@@ -193,8 +200,52 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
backward_op.run(scope, ctx)
out = np.array(scope.find_var(grad_name).get_tensor())
return out
return [
np.array(scope.find_var(grad_name).get_tensor())
for grad_name in grad_names
]
def append_input_output(block, op_proto, np_list, is_input):
'''Insert VarDesc and generate Python variable instance'''
proto_list = op_proto.inputs if is_input else op_proto.outputs
def create_var(block, name, np_list, var_proto):
if name not in np_list:
assert var_proto.intermediate, "{} not found".format(name)
shape = None
lod_level = None
else:
np_value = np_list[name]
if isinstance(np_value, tuple):
shape = list(np_value[0].shape)
lod_level = len(np_value[1])
else:
shape = list(np_value.shape)
lod_level = 0
return block.create_var(
dtype="float32", shape=shape, lod_level=lod_level, name=name)
var_dict = {}
for var_proto in proto_list:
var_name = str(var_proto.name)
if is_input:
if (var_name not in np_list) and var_proto.dispensable:
continue
assert (var_name in np_list) or (var_proto.dispensable), \
"Missing {} as input".format(var_name)
if var_proto.duplicable:
assert isinstance(np_list[var_name], list), \
"Duplicable {} should be set as list".format(var_name)
var_list = []
for (name, np_value) in np_list[var_name]:
var_list.append(
create_var(block, name, {name: np_value}, var_proto))
var_dict[var_name] = var_list
else:
var_dict[var_name] = create_var(block, var_name, np_list, var_proto)
return var_dict
class OpTest(unittest.TestCase):
......@@ -213,48 +264,93 @@ class OpTest(unittest.TestCase):
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
def feed_var(self, input_vars, place):
feed_map = {}
for var_name in input_vars:
if isinstance(input_vars[var_name], list):
for name, np_value in self.inputs[var_name]:
tensor = core.LoDTensor()
tensor.set(np_value, place)
feed_map[name] = tensor
else:
tensor = core.LoDTensor()
if isinstance(self.inputs[var_name], tuple):
tensor.set(self.inputs[var_name][0], place)
tensor.set_lod(self.inputs[var_name][1])
else:
tensor.set(self.inputs[var_name], place)
feed_map[var_name] = tensor
return feed_map
def check_output_with_place(self, place, atol):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return
set_input(self.scope, self.op, self.inputs, place)
ctx = core.DeviceContext.create(place)
self.op.run(self.scope, ctx)
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
program = Program()
block = program.global_block()
inputs = append_input_output(block, op_proto, self.inputs, True)
outputs = append_input_output(block, op_proto, self.outputs, False)
op = block.append_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=self.attrs if hasattr(self, "attrs") else dict())
fetch_list = []
for var_name, var in outputs.iteritems():
if var_name in self.outputs:
if isinstance(var, list):
for v in var:
fetch_list.append(v)
else:
fetch_list.append(var)
feed_map = self.feed_var(inputs, place)
for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
exe = Executor(place)
outs = exe.run(program, feed=feed_map, fetch_list=fetch_list)
for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs:
continue
def find_actual(target_name, fetch_list):
found = [
i for i, var in enumerate(fetch_list)
if var.name == target_name
]
self.assertTrue(
len(found) == 1, "Found {} {}".format(
len(found), target_name))
return found[0]
if out_dup:
sub_out = self.outputs[out_name]
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for sub_out_name, expect in sub_out:
actual = np.array(
self.scope.find_var(sub_out_name).get_tensor())
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
"Output (" + sub_out_name + ") has diff at " +
str(place))
else:
actual = np.array(self.scope.find_var(out_name).get_tensor())
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
"Output (" + out_name + ") has diff at " + str(place))
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
if core.is_compile_gpu():
if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
places.append(core.GPUPlace(0))
for place in places:
self.check_output_with_place(place, atol)
......@@ -310,11 +406,9 @@ class OpTest(unittest.TestCase):
]
cpu_place = core.CPUPlace()
cpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, cpu_place, no_grad_set)
for grad_name in grad_names
]
cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names, cpu_place,
no_grad_set)
self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error,
......@@ -322,11 +416,9 @@ class OpTest(unittest.TestCase):
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, gpu_place, no_grad_set)
for grad_name in grad_names
]
gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names,
gpu_place, no_grad_set)
self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error,
......
......@@ -16,7 +16,9 @@ class TestAccuracyOp(OpTest):
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {'Accuracy': [num_correct / float(n)]}
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype("float32")
}
def test_check_output(self):
self.check_output()
......
......@@ -172,8 +172,8 @@ class TestBRelu(OpTest):
def setUp(self):
self.op_type = "brelu"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
t_min = 1
t_max = 4
t_min = 1.0
t_max = 4.0
# The same with TestAbs
x[np.abs(x - t_min) < 0.005] = t_min + 0.02
x[np.abs(x - t_max) < 0.005] = t_max + 0.02
......@@ -218,7 +218,7 @@ class TestSoftRelu(OpTest):
def setUp(self):
self.op_type = "soft_relu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
threshold = 2
threshold = 2.0
# The same reason with TestAbs
x[np.abs(x - threshold) < 0.005] = threshold + 0.02
x[np.abs(x + threshold) < 0.005] = -threshold + 0.02
......@@ -303,7 +303,7 @@ class TestPow(OpTest):
def setUp(self):
self.op_type = "pow"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.attrs = {'factor': 3}
self.attrs = {'factor': 3.0}
self.outputs = {'Y': np.power(self.inputs['X'], 3)}
def test_check_output(self):
......
......@@ -37,14 +37,14 @@ class TestCase1(TestClipOp):
def initTestCase(self):
self.shape = (8, 16, 8)
self.max = 0.7
self.min = 0
self.min = 0.0
class TestCase2(TestClipOp):
def initTestCase(self):
self.shape = (8, 16)
self.max = 1
self.min = 0
self.max = 1.0
self.min = 0.0
class TestCase3(TestClipOp):
......
import unittest
import numpy as np
from op_test import OpTest
class TestIdentityOp(OpTest):
def setUp(self):
self.op_type = "identity"
self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
self.outputs = {'Y': self.inputs['X']}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
if __name__ == "__main__":
unittest.main()
......@@ -8,7 +8,8 @@ class TestLookupTableOp(OpTest):
self.op_type = "lookup_table"
table = np.random.random((17, 31)).astype("float32")
ids = np.random.randint(0, 17, 4).astype("int32")
self.inputs = {'W': table, 'Ids': ids}
ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]}
def test_check_output(self):
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -5,3 +5,4 @@ py_test(test_topology SRCS test_topology.py)
py_test(test_rnn_layer SRCS test_rnn_layer.py)
py_test(test_parameters SRCS test_parameters.py)
py_test(test_data_feeder SRCS test_data_feeder.py)
py_test(test_paramconf_order SRCS test_paramconf_order.py)
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册