提交 bc3ec536 编写于 作者: C chengduoZH

remove conflict

......@@ -108,14 +108,11 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
if(WITH_MKL)
set(WITH_MKLML ON)
set(WITH_MKLDNN ${AVX2_FOUND})
if(NOT WITH_MKLDNN)
message(WARNING "Do not have AVX2 intrinsics and disabled MKL-DNN")
endif()
set(WITH_MKLML ${WITH_MKL})
if (WITH_MKL AND AVX2_FOUND)
set(WITH_MKLDNN ON)
else()
set(WITH_MKLML OFF)
message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
set(WITH_MKLDNN OFF)
endif()
......@@ -166,10 +163,7 @@ set(EXTERNAL_LIBS
)
if(WITH_GPU)
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)
include(cuda)
endif(WITH_GPU)
if(WITH_MKLML)
......
if(NOT WITH_GPU)
return()
endif()
set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
# detect_installed_gpus(out_variable)
function(detect_installed_gpus out_variable)
if(NOT CUDA_gpu_detect_output)
set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
file(WRITE ${cufile} ""
"#include <cstdio>\n"
"int main() {\n"
" int count = 0;\n"
" if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
" if (count == 0) return -1;\n"
" for (int device = 0; device < count; ++device) {\n"
" cudaDeviceProp prop;\n"
" if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
" std::printf(\"%d.%d \", prop.major, prop.minor);\n"
" }\n"
" return 0;\n"
"}\n")
execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}"
"--run" "${cufile}"
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if(nvcc_res EQUAL 0)
# only keep the last line of nvcc_out
STRING(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
STRING(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
list(GET nvcc_out -1 nvcc_out)
string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_installed_gpus tool" FORCE)
endif()
endif()
if(NOT CUDA_gpu_detect_output)
message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
else()
set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
endif()
endfunction()
########################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
# select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
# List of arch names
set(archs_names "Kepler" "Maxwell" "Pascal" "All" "Manual")
set(archs_name_default "All")
if(NOT CMAKE_CROSSCOMPILING)
list(APPEND archs_names "Auto")
endif()
# set CUDA_ARCH_NAME strings (so it will be seen as dropbox in CMake-Gui)
set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU achitecture.")
set_property( CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names} )
mark_as_advanced(CUDA_ARCH_NAME)
# verify CUDA_ARCH_NAME value
if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};")
string(REPLACE ";" ", " archs_names "${archs_names}")
message(FATAL_ERROR "Only ${archs_names} architeture names are supported.")
endif()
if(${CUDA_ARCH_NAME} STREQUAL "Manual")
set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
else()
unset(CUDA_ARCH_BIN CACHE)
unset(CUDA_ARCH_PTX CACHE)
endif()
if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
set(cuda_arch_bin "30 35")
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50")
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
set(cuda_arch_bin "70")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
detect_installed_gpus(cuda_arch_bin)
else() # (${CUDA_ARCH_NAME} STREQUAL "Manual")
set(cuda_arch_bin ${CUDA_ARCH_BIN})
endif()
# remove dots and convert to lists
string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX REPLACE "\\." "" cuda_arch_ptx "${CUDA_ARCH_PTX}")
string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
list(REMOVE_DUPLICATES cuda_arch_bin)
list(REMOVE_DUPLICATES cuda_arch_ptx)
set(nvcc_flags "")
set(nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs
foreach(arch ${cuda_arch_bin})
if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
# User explicitly specified PTX for the concrete BIN
list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
else()
# User didn't explicitly specify PTX for the concrete BIN, we assume PTX=BIN
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
list(APPEND nvcc_archs_readable sm_${arch})
endif()
endforeach()
# Tell NVCC to add PTX intermediate code for the specified architectures
foreach(arch ${cuda_arch_ptx})
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
list(APPEND nvcc_archs_readable compute_${arch})
endforeach()
string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 7.0)
set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
# CUDA 8 may complain that sm_20 is no longer supported. Suppress the
# warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
endif()
include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)
# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}")
# Set C++11 support
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
# Set :expt-relaxed-constexpr to suppress Eigen warnings
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
......@@ -149,58 +149,3 @@ endforeach()
foreach(flag ${GPU_COMMON_FLAGS})
safe_set_nvflag(${flag})
endforeach()
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()
function(specify_cuda_arch cuda_version cuda_arch)
if(${cuda_version} VERSION_GREATER "8.0")
foreach(capability 61 62)
if(${cuda_arch} STREQUAL ${capability})
list(APPEND __arch_flags " -gencode arch=compute_${cuda_arch},code=sm_${cuda_arch}")
endif()
endforeach()
elseif(${cuda_version} VERSION_GREATER "7.0" and ${cuda_arch} STREQUAL "53")
list(APPEND __arch_flags " -gencode arch=compute_${cuda_arch},code=sm_${cuda_arch}")
endif()
endfunction()
# Common gpu architectures: Kepler, Maxwell
foreach(capability 30 35 50)
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
endif()
# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
list(APPEND CUDA_NVCC_FLAGS --expt-relaxed-constexpr)
endif()
# Custom gpu architecture
set(CUDA_ARCH)
if(CUDA_ARCH)
specify_cuda_arch(${CUDA_VERSION} ${CUDA_ARCH})
endif()
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
......@@ -335,6 +335,16 @@ bilinear_interp
.. autoclass:: paddle.v2.layer.bilinear_interp
:noindex:
dot_prod
---------
.. autoclass:: paddle.v2.layer.dot_prod
:noindex:
out_prod
--------
.. autoclass:: paddle.v2.layer.out_prod
:noindex:
power
-----
.. autoclass:: paddle.v2.layer.power
......@@ -372,6 +382,11 @@ cos_sim
.. autoclass:: paddle.v2.layer.cos_sim
:noindex:
l2_distance
-----------
.. autoclass:: paddle.v2.layer.l2_distance
:noindex:
trans
-----
.. autoclass:: paddle.v2.layer.trans
......
......@@ -513,19 +513,14 @@ ParamGradInfoMap AppendBackward(
const int root_block_idx = 0;
auto root_block = program_desc.MutableBlock(root_block_idx);
// insert fill one op for target
// TODO(qiao) add some check to the target.
std::string fill_one_op_out = GradVarName(target.Name());
std::vector<int64_t> target_shape_desc = target.Shape();
std::vector<int> target_shape;
std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape),
[](int64_t dim) { return static_cast<int>(dim); });
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
std::unique_ptr<OpDescBind> fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape},
{{"shape", std::vector<int>{1}},
{"value", static_cast<float>(1.0)},
{"data_type", target.GetDataType()}}));
// infer var type of fill_one_op
......
......@@ -508,6 +508,7 @@ TEST(Backward, simple_single_op) {
op->SetOutput("Out", {"out"});
auto target = f::VarDescBind("out");
target.SetShape({1});
auto var_to_grad = AppendBackward(program, target, {});
ASSERT_EQ(block->AllOps().size(), 3UL);
......@@ -544,6 +545,7 @@ TEST(Backward, default_attribute) {
op->CheckAttrs();
auto target = f::VarDescBind("out");
target.SetShape({1});
AppendBackward(program, target, {});
ASSERT_EQ(block->AllOps().size(), 3UL);
......@@ -581,6 +583,7 @@ TEST(Backward, simple_mult_op) {
op3->SetOutput("Out", {"out3"});
auto target = f::VarDescBind("out3");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {});
......@@ -670,6 +673,7 @@ TEST(Backward, intermedia_var_no_grad) {
op4->SetOutput("Out", {"out4"});
auto target = f::VarDescBind("out4");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"out3"});
......@@ -730,6 +734,7 @@ TEST(Backward, var_no_grad) {
op2->SetOutput("Z", {"z2"});
auto target = f::VarDescBind("z2");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"z1"});
......@@ -810,6 +815,7 @@ TEST(Backward, shared_var) {
op3->SetOutput("Out", {"out3"});
auto target = f::VarDescBind("out3");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {});
......@@ -888,6 +894,7 @@ TEST(Backward, half_backward) {
op1->SetOutput("Out", {"out"});
auto target = f::VarDescBind("out");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"b"});
f::OpDescBind *fill_op = block->AllOps()[forward_len];
......
......@@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) {
return typeid(int);
case DataType::INT64:
return typeid(int64_t);
case DataType::BOOL:
return typeid(bool);
default:
PADDLE_THROW("Not support type %d", type);
}
......@@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) {
case DataType::INT64:
visitor.template operator()<int64_t>();
break;
case DataType::BOOL:
visitor.template operator()<bool>();
break;
default:
PADDLE_THROW("Not supported");
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
/**
* @brief A layer for computing the dot product of two vectors.
* Input1: vector (batchSize * dim)
* Input2: vector (batchSize * dim)
* Output: a matrix: (batchSize * 1)
*/
class DotProdLayer : public Layer {
public:
explicit DotProdLayer(const LayerConfig& config) : Layer(config) {}
~DotProdLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(dot_prod, DotProdLayer);
bool DotProdLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2U);
CHECK_EQ(1UL, getSize())
<< "The output dimensionality of this layer should be fixed to 1.";
return true;
}
void DotProdLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
size_t batchSize = inV0->getHeight();
CHECK_EQ(inV1->getHeight(), batchSize);
CHECK_EQ(inV0->getWidth(), inV1->getWidth());
{
REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
reserveOutput(batchSize, 1);
}
MatrixPtr outV = getOutputValue();
{
REGISTER_TIMER_INFO("FwDotProdTimer", getName().c_str());
outV->sumOfProducts(*inV0, *inV1, 1, 0);
}
}
void DotProdLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr outG = getOutputGrad();
MatrixPtr inG0 = getInputGrad(0);
MatrixPtr inG1 = getInputGrad(1);
{
REGISTER_TIMER_INFO("BwDotProdTimer", getName().c_str());
if (inG0) {
inG0->addRowScale(0, *inV1, *outG);
}
if (inG1) {
inG1->addRowScale(0, *inV0, *outG);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "L2DistanceLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(l2_distance, L2DistanceLayer);
bool L2DistanceLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
<< "only two inputs.";
CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
<< "is fixed to be 1.";
return true;
}
void L2DistanceLayer::forward(PassType passType) {
Layer::forward(passType);
const auto inV1 = getInputValue(0);
const auto inV2 = getInputValue(1);
CHECK(inV1 && inV2);
CHECK_EQ(inV1->getHeight(), inV2->getHeight())
<< "The height of two inputs of this layer must be the same.";
CHECK_EQ(inV1->getWidth(), inV2->getWidth())
<< "The width of two inputs of this layer must be the same.";
int batchSize = inV1->getHeight();
int output_dim = getSize();
{
REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
reserveOutput(batchSize, output_dim);
auto outV = getOutputValue();
CHECK(outV) << "The output matrix should not be null.";
Matrix::resizeOrCreate(
inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);
inputSub_->assign(*inV1);
inputSub_->sub(*inV2);
outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
outV->sqrt2(*outV);
}
}
void L2DistanceLayer::backward(const UpdateCallback& callback) {
const auto outG = getOutputGrad();
const auto outV = getOutputValue();
CHECK(outG && outV);
auto inGrad1 = getInputGrad(0);
auto inGrad2 = getInputGrad(1);
{
REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
if (inGrad1 || inGrad2) {
outV->scalarDiv(*outV, 1.);
outV->dotMul(*outG, *outV);
}
if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);
if (inGrad2) {
inputSub_->mulScalar(-1.);
inGrad2->addRowScale(0, *inputSub_, *outV);
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief The layer calculates the l2 distance between two input vectors.
* \f[
* f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)}
* \f]
*
* - Input1: A vector (batchSize * dataDim)
* - Input2: A vector (batchSize * dataDim)
* - Output: A vector (batchSize * 1)
*
* The configuration api is: l2_distance_layer.
*/
class L2DistanceLayer : public Layer {
public:
explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
~L2DistanceLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
private:
// Store the result of subtracting Input2 from Input1 in forward computation,
// which will be reused in backward computation.
MatrixPtr inputSub_;
};
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MKLDNNConcatLayer.h"
using namespace mkldnn; // NOLINT
typedef memory::format format;
namespace paddle {
REGISTER_LAYER(mkldnn_concat, MKLDNNConcatLayer);
bool MKLDNNConcatLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
return false;
}
CHECK_GT(inputLayers_.size(), 1UL);
CHECK(!biasParameter_);
return true;
}
void MKLDNNConcatLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw);
ic = inputLayers_[0]->getSize() / ih / iw;
CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw);
CHECK_GT(inputLayers_.size(), 1UL);
channels_.resize(inputLayers_.size());
channels_[0] = ic;
// need change the output channel, so use oc_ instead
// TODO(TJ): change API, use &oc
oc_ = ic;
for (size_t i = 1; i < inputLayers_.size(); i++) {
int batchsize, height, witdh;
reshapeInput(batchsize, height, witdh, i);
CHECK_EQ(bs, batchsize);
CHECK_EQ(ih, height);
CHECK_EQ(iw, witdh);
channels_[i] = inputLayers_[i]->getSize() / height / witdh;
CHECK_EQ((size_t)channels_[i] * height * witdh, inputLayers_[i]->getSize());
oc_ += channels_[i];
}
oh = ih;
ow = iw;
reshapeOutput(oh, ow);
resizeOutput(bs, oc_ * oh * ow);
}
void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
resetFwdBuffers(inVals_, out);
in = inVals_[0];
std::shared_ptr<concat::primitive_desc> fwdPD;
resetFwdPD(fwdPD, inVals_, out);
resetFwdPipeline(pipeline, fwdPD, inVals_, out);
}
void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
resetBwdBuffers(inGrads_, out);
in = inGrads_[0];
resetBwdPipeline(pipeline, bwds_, inGrads_, out);
}
void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out) {
inputs.resize(inputLayers_.size());
bool has8c = false, has16c = false, hasnc = false;
for (size_t i = 0; i < inputs.size(); i++) {
// resetInValue will use ic_ so temporary change as current input's channel
// TODO(TJ): change ic_ as vector then can remove channels_
ic_ = channels_[i];
resetInValue(inputs[i], nullptr, i);
CHECK(inputs[i]);
auto dm = inputs[i]->getDims();
// inputs format can be different, but ndims must equal
CHECK(i == 0 || dm.size() == inputs[0]->getDims().size());
CHECK_EQ(bs_, dm[0]);
CHECK_EQ(channels_[i], dm[1]);
if (dm.size() > 2) {
CHECK_EQ(ih_, dm[2]);
CHECK_EQ(iw_, dm[3]);
}
if (inputs[i]->getFormat() == format::nc) {
hasnc = true;
}
if (inputs[i]->getFormat() == format::nChw8c) {
has8c = true;
}
if (inputs[i]->getFormat() == format::nChw16c) {
has16c = true;
}
}
// change back, ic_ always save the input 0 size
ic_ = channels_[0];
format outFmt;
if (has16c && oc_ % 16 == 0) {
outFmt = format::nChw16c;
} else if (has8c && oc_ % 8 == 0) {
outFmt = format::nChw8c;
} else if (hasnc) {
CHECK(oh_ == 1 && ow_ == 1);
outFmt = format::nc;
} else {
outFmt = format::nchw;
}
memory::dims outDims =
hasnc ? memory::dims{bs_, oc_} : memory::dims{bs_, oc_, oh_, ow_};
auto outPD = MKLDNNMatrix::createPrimitiveDesc(outDims, outFmt, engine_);
resetOutValue(out, outPD);
}
void MKLDNNConcatLayer::resetFwdPD(std::shared_ptr<concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr out) {
std::vector<memory::primitive_desc> srcPDs;
for (size_t i = 0; i < inputs.size(); i++) {
srcPDs.push_back(inputs[i]->getPrimitiveDesc());
}
CHECK(out);
pd.reset(new concat::primitive_desc(out->getMemoryDesc(), axis_, srcPDs));
CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
}
void MKLDNNConcatLayer::resetFwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out) {
std::vector<primitive::at> srcs;
for (size_t i = 0; i < inputs.size(); i++) {
srcs.push_back(*(inputs[i]));
}
fwd_.reset(new concat(*pd, srcs, *out));
pipeline.push_back(*fwd_);
}
void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out) {
CHECK(outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc());
CHECK(out);
inputs.resize(inputLayers_.size());
for (size_t i = 0; i < inputs.size(); i++) {
CHECK(inVals_[i]);
// resetInGrad will use inVal_
// TODO(TJ): change move inVals_ to MKLDNNLayer ans remove inVal_
inVal_ = inVals_[i];
resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i);
CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc());
}
// change back, inVal_ always save the input 0
inVal_ = inVals_[0];
}
void MKLDNNConcatLayer::resetBwdPipeline(
std::vector<mkldnn::primitive>& pipeline,
std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out) {
// reset the backward primitives
memory::dims offsets = {0, 0, 0, 0};
prims.resize(inputs.size());
CHECK_EQ(inputs.size(), channels_.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto viewPD = view::primitive_desc(
out->getPrimitiveDesc(), inputs[i]->getDims(), offsets);
auto bwdPD = reorder::primitive_desc(viewPD.dst_primitive_desc(),
inputs[i]->getPrimitiveDesc());
prims[i].reset(new reorder(bwdPD, *out, *(inputs[i])));
offsets[axis_] += channels_[i];
// push to pipeline
pipeline.push_back(*prims[i]);
}
}
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "MKLDNNLayer.h"
#include "mkldnn.hpp"
namespace paddle {
/**
* @brief A subclass of MKLDNNLayer Concatenate layer.
*
* The config file api is mkldnn_concat
*/
class MKLDNNConcatLayer : public MKLDNNLayer {
protected:
std::vector<MKLDNNMatrixPtr> inVals_;
std::vector<MKLDNNMatrixPtr> inGrads_;
std::vector<std::shared_ptr<mkldnn::primitive>> bwds_;
// input channel numbers
std::vector<int> channels_;
// concat_dimension in MKLDNN
// if axis_ == 0, concat batchsize
// if axis_ == 1, concat channel (default)
int axis_;
public:
explicit MKLDNNConcatLayer(const LayerConfig& config)
: MKLDNNLayer(config), axis_(1) {}
~MKLDNNConcatLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override;
void printSizeInfo() override {
CHECK_EQ(channels_.size(), inputLayers_.size());
for (size_t i = 0; i < channels_.size(); ++i) {
VLOG(MKLDNN_SIZES) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << bs_ << ", " << channels_[i] << ", " << ih_
<< ", " << iw_;
}
VLOG(MKLDNN_SIZES) << "Output: " << bs_ << ", " << oc_ << ", " << oh_
<< ", " << ow_;
}
void printValueFormat() override {
for (size_t i = 0; i < inVals_.size(); ++i) {
VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << inVals_[i]->getFormat() << " >>>";
}
if (outVal_) {
VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> ";
}
if (extOutVal_) {
VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
}
}
void printGradFormat() override {
if (extOutGrad_) {
VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
}
if (outGrad_) {
VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
}
for (size_t i = 0; i < inGrads_.size(); ++i) {
VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << inGrads_[i]->getFormat() << "<<<";
}
}
protected:
/**
* Forward functions: reset buffers(inputs, output, bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr out);
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
* reset primitives and pipeline
*/
void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out);
};
} // namespace paddle
......@@ -21,7 +21,7 @@ namespace paddle {
bool MKLDNNLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
CHECK(FLAGS_use_mkldnn) << "MKLDNNLayers only support use_mkldnn."
<< "Please set WITH_MKL=ON "
<< "and set use_mkldnn=True";
CHECK(!useGpu_) << "Do not support GPU yet";
......@@ -138,8 +138,11 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) {
}
}
void MKLDNNLayer::reshapeInput(int& batchsize, int& height, int& width) {
const Argument& input = inputLayers_[0]->getOutput();
void MKLDNNLayer::reshapeInput(int& batchsize,
int& height,
int& width,
size_t inputIdx) {
const Argument& input = inputLayers_[inputIdx]->getOutput();
batchsize = input.getBatchSize();
int h = input.getFrameHeight();
int w = input.getFrameWidth();
......
......@@ -178,7 +178,10 @@ protected:
/**
* reshape the input image sizes and input batchsize
*/
void reshapeInput(int& batchsize, int& height, int& width);
void reshapeInput(int& batchsize,
int& height,
int& width,
size_t inputIdx = 0);
/**
* reshape output image sizes
......
......@@ -29,7 +29,7 @@ gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
########## test_Mkldnn layers and activations ##########
########## test_MKLDNN layers and activations ##########
if(WITH_MKLDNN)
add_unittest_without_exec(test_MKLDNN
test_MKLDNN.cpp
......
......@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle {
/**
* @brief test the functionality of Mkldnnlayers
* @brief test the functionality of MKLDNNlayers and MKLDNNActivations
* refer to paddle original function
*/
class MKLDNNTester {
......
......@@ -583,6 +583,7 @@ TEST(Layer, maxoutLayer) {
testLayerGrad(config, "maxout", 10, false, useGpu);
}
}
void testFcLayer(string format, size_t nnz) {
TestConfig config;
config.biasSize = 1024;
......@@ -1081,6 +1082,21 @@ TEST(Layer, InterpolationLayer) {
}
}
TEST(Layer, DotProdLayer) {
TestConfig config;
config.layerConfig.set_type("dot_prod");
config.layerConfig.set_size(1);
config.inputDefs.push_back({INPUT_DATA, "layer_0", 10, 0});
config.layerConfig.add_inputs();
config.inputDefs.push_back({INPUT_DATA, "layer_1", 10, 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "dot_prod", 10, false, useGpu);
}
}
TEST(Layer, OuterProdLayer) {
TestConfig config;
config.layerConfig.set_type("out_prod");
......@@ -2429,6 +2445,25 @@ TEST(Layer, ScaleSubRegionLayer) {
}
}
TEST(Layer, L2DistanceLayer) {
TestConfig config;
config.layerConfig.set_type("l2_distance");
config.layerConfig.set_size(1);
config.biasSize = 0;
const size_t input_dim = 27;
const size_t batch_size = 11;
config.inputDefs.push_back({INPUT_DATA, "layer_0", input_dim, 0});
config.inputDefs.push_back({INPUT_DATA, "layer_1", input_dim, 0});
config.layerConfig.add_inputs();
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "l2_distance", batch_size, false, useGpu);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -313,6 +313,47 @@ TEST(MKLDNNLayer, AddtoLayer) {
testAddtoLayer({4, 12, 1, 1}, 3);
}
static void getMKLDNNConcatConfig(TestConfig& cfg,
const std::vector<testImageDesc>& inputs) {
CHECK_GE(inputs.size(), 2) << "at least two inputs";
int oc = inputs[0].ic;
for (size_t i = 1; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i].bs, inputs[0].bs);
CHECK_EQ(inputs[i].ih, inputs[0].ih);
CHECK_EQ(inputs[i].iw, inputs[0].iw);
oc += inputs[i].ic;
}
cfg.biasSize = 0;
cfg.layerConfig.set_type("mkldnn_concat");
cfg.layerConfig.set_size(oc * inputs[0].ih * inputs[0].iw);
cfg.layerConfig.set_active_type("relu");
for (size_t i = 0; i < inputs.size(); ++i) {
std::stringstream ss;
ss << "layer_" << i;
cfg.inputDefs.push_back(
{INPUT_DATA,
ss.str(),
(size_t)(inputs[i].ic) * inputs[i].ih * inputs[i].iw,
0});
LayerInputConfig* input = cfg.layerConfig.add_inputs();
ImageConfig* img_conf = input->mutable_image_conf();
img_conf->set_channels(inputs[i].ic);
img_conf->set_img_size_y(inputs[i].ih);
img_conf->set_img_size(inputs[i].iw);
}
}
void testConcatLayer(const std::vector<testImageDesc>& inputs) {
TestConfig dnnConfig;
getMKLDNNConcatConfig(dnnConfig, inputs);
RUN_MKLDNN_TEST_LAYER(dnnConfig, "concat", inputs[0])
}
TEST(MKLDNNLayer, ConcatLayer) {
testConcatLayer({{64, 128, 1, 1}, {64, 32, 1, 1}, {64, 64, 1, 1}});
testConcatLayer({{32, 100, 8, 8}, {32, 10, 8, 8}});
}
void testActivation(std::string actType, const testImageDesc& pm) {
// TODO(TJ): remove me when paddle support elu activation
if (actType == "mkldnn_elu") {
......
......@@ -61,6 +61,18 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
......@@ -68,23 +80,23 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
if ("${TARGET}" STREQUAL "compare_op")
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif()
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
if ("${TARGET}" STREQUAL "logical_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
file(APPEND ${pybind_file} "USE_OP(logical_and);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
# conv_transpose_op contains several operators
......@@ -93,12 +105,12 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
endif()
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
# conv_transpose_cudnn_op contains two operators
if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
endif()
# save_restore_op contains several operators
......
......@@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
......@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc,
......@@ -241,9 +240,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
......@@ -261,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>);
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
......@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
......@@ -17,11 +17,15 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
......@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
framework::OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
.SetDefault({1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};
class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......@@ -44,7 +61,22 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
......@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
// N, M, H, W
// (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// N, C, O_h, O_w
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// M, C, K_h, K_w
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
......@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;
// Input: (N, M, H, W)
// Input: (N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_H, O_W)
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter (M, C, K_H, K_W)
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
......@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
math::set_constant(ctx.device_context(), input_grad, 0);
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
......@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
math::set_constant(ctx.device_context(), filter_grad, 0);
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
......@@ -231,6 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
......@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
......@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
......@@ -24,8 +24,17 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
const framework::Tensor& src, const size_t* index,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
......@@ -33,7 +42,6 @@ class GRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
......@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data);
Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
gru_value.prevOutValue = ordered_h0.data<T>();
} else {
gru_value.prevOutValue = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; n++) {
......@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
......@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> {
zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
Tensor ordered_h0, ordered_h0_grad;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
}
if (h0_grad) {
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
}
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
......@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(dev_ctx, h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
}
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
gru_grad.prevOutGrad =
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
} else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
......@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::ColwiseSum<Place, T> col_sum;
col_sum(dev_ctx, batch_gate_grad, bias_grad);
}
if (h0 && h0_grad) {
ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
order, h0_grad, false);
}
}
void Compute(const framework::ExecutionContext& context) const override {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/logical_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BinaryLogicalOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) Left hand operand of %s operator",
comment.type));
AddInput("Y",
string::Sprintf("(LoDTensor) Right hand operand of %s operator",
comment.type));
AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(%s Operator
It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors.
Each element of Out is calculated by %s
)DOC",
comment.type, comment.equation));
}
};
template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
UnaryLogicalOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
comment.type));
AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(%s Operator
It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors.
Each element of Out is calculated by %s
)DOC",
comment.type, comment.equation));
}
};
template <typename OpComment>
class BinaryLogicalOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of %s operator must not be null", comment.type);
PADDLE_ENFORCE(context->HasInput("Y"),
"Input(Y) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
template <typename OpComment>
class UnaryLogicalOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"),
"Input(X) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
class LogicalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::LogicalOp, \
::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::LogicalOp, \
::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y");
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X && Y");
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X");
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)");
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
paddle::operators::LogicalXorFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/logical_op.h"
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU,
paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU,
paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU,
paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU,
paddle::operators::LogicalXorFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T>
struct LogicalAndFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; }
};
template <typename T>
struct LogicalOrFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; }
};
template <typename T>
struct LogicalNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a) const { return !a; }
};
template <typename T>
struct LogicalXorFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return (a || b) && !(a && b);
}
};
template <typename Place, typename Functor>
class BinaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<Place> trans;
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
binary_func);
}
};
template <typename Place, typename Functor>
class UnaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
Functor unary_func;
platform::Transform<Place> trans;
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
out->mutable_data<bool>(context.GetPlace()), unary_func);
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##Place, functor<bool>>);
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##Place, functor<bool>>);
......@@ -20,6 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
......@@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
......@@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
}
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
......@@ -135,8 +147,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
......@@ -151,5 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
......@@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
ops::PoolKernel<paddle::platform::CPUPlace, float>,
ops::PoolKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>);
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
ops::PoolGradKernel<paddle::platform::CPUPlace, double>);
......@@ -17,11 +17,15 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
ops::PoolKernel<paddle::platform::GPUPlace, float>,
ops::PoolKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
......@@ -266,10 +266,12 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>);
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double, int>);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double, int>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
......@@ -277,7 +279,9 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>);
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double, int>);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double, int>)
......@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>);
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double, int>);
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double, int>)
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>);
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double, int>);
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>)
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double, int>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_slice_op.h"
namespace paddle {
namespace operators {
class SequenceSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Offset"),
"Input(Offset) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Length"),
"Input(Length) of SequenceSliceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSliceOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto offset_dim = ctx->GetInputDim("Offset");
auto length_dim = ctx->GetInputDim("Length");
PADDLE_ENFORCE_EQ(
offset_dim.size(), 2UL,
"Only support one level sequence now, The rank of offset must be 2.");
PADDLE_ENFORCE_EQ(
length_dim.size(), 2UL,
"Only support one level sequence now, The rank of Length must be 2.");
// Initialize the output's dims to maximum,
// and re-set to real dims by the value of Offset and Length at kernel
ctx->SetOutputDim("Out", input_dims);
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class SequenceSliceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
"The gradient of X should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSliceOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor), "
"the input of SequenceSliceOp.");
AddInput("Offset",
"(Tensor), "
"a vector<int> to describe the offset of every input sequence for "
"sub sequence item.");
AddInput("Length",
"(Tensor), "
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), the output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
The operator crops a subsequence from given sequence with given start offset and subsequence length.
It only supports sequence (LoD Tensor with level number is 1).
- Case:
X = [[a1, a2;
b1, b2;
c1, c2]
[d1, d2;
e1, e2]]
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
Offset = [[0], [1]]; Length = [[2], [1]]
Out = [[a1, a2;
b1, b2]
[e1, e2]]
LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
NOTE: The first dimension size of input, the size of offset and Length, should be equal. The offset start from 0.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_slice, ops::SequenceSliceOp, ops::SequenceSliceOpMaker,
sequence_slice_grad, ops::SequenceSliceGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_slice_grad,
ops::SequenceSliceGradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_slice_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_slice_grad,
ops::SequenceSliceGradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename T>
inline LoD SequenceSliceLoD(const T& in, const int64_t* offset_data,
const int64_t* length_data) {
auto out_lod = in.lod();
size_t lod_offset = 0;
auto n = in.lod()[0].size() - 1;
out_lod[0][0] = 0;
for (size_t i = 0; i < n; ++i) {
lod_offset += length_data[i];
out_lod[0][i+1] = lod_offset;
}
return out_lod;
}
template <typename Place, typename T>
class SequenceSliceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
auto* offset = ctx.Input<Tensor>("Offset");
auto* length = ctx.Input<Tensor>("Length");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = in->lod();
auto n = lod[0].size() - 1;
PADDLE_ENFORCE_EQ(lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same")
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(offset->dims()[0]),
"The size of input-sequence and offset-array should be the same")
const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
framework::Tensor offset_cpu;
framework::Tensor length_cpu;
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
length_data = length_cpu.data<int64_t>();
}
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i)
PADDLE_ENFORCE_LT(
lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1],
"The target tensor's length overflow.")
}
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
auto out_dims = in->dims();
out_dims[0] = out_lod[0][out_lod[0].size() - 1];
out->Resize(out_dims);
out->set_lod(out_lod);
auto in_stride = framework::stride(in->dims());
auto out_stride = framework::stride(out->dims());
size_t out_offset = 0;
for (size_t i = 0; i < n; ++i) {
Tensor in_t =
in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] +
length_data[i]));
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
in_stride, in_t.dims(), out_stride,
out->data<T>() + out_offset);
out_offset += length_data[i] * in_stride[0];
}
}
};
template <typename Place, typename T>
class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<LoDTensor>("X");
auto* offset = ctx.Input<Tensor>("Offset");
auto* length = ctx.Input<Tensor>("Length");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
framework::Tensor offset_cpu;
framework::Tensor length_cpu;
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
length_data = length_cpu.data<int64_t>();
}
auto lod = in->lod();
auto out_lod = out_grad->lod();
if (x_grad) {
x_grad->mutable_data<T>(ctx.GetPlace());
x_grad->set_lod(in->lod());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
auto out_grad_stride = framework::stride(out_grad->dims());
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
static_cast<int>(out_lod[0][i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
auto x_grad_stride = framework::stride(x_grad->dims());
Tensor x_grad_t = x_grad->Slice(
static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -30,7 +30,7 @@ void sgdUpdateCpu(real learningRate,
const real* grad,
real* momentumVec) {
decayRate *= learningRate;
#ifdef PADDLE_USE_MKLDNN
#ifdef PADDLE_USE_MKLML
#pragma omp parallel for
#endif
for (size_t i = 0; i < size; ++i) {
......
......@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} while (false)
enum class DataLayout {
enum class DataLayout { // Not use
kNHWC,
kNCHW,
kNCDHW,
kNCHW_VECT_C,
};
......@@ -107,12 +108,15 @@ class CudnnDataType<double> {
}
};
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
inline cudnnTensorFormat_t GetCudnnTensorFormat(
const DataLayout& order) { // Not use
switch (order) {
case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
......@@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// FIXME(typhoonzero): Assume using NCHW order
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
......@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
const cudnnDataType_t type,
const std::vector<int>& kernel,
const int groups = 1) {
// filter layout: MCHW, where M is the number of
// filter layout: MCHW(MCDHW), where M is the number of
// output image channels, C is the number of input image channels,
// H and W is height and width of filter.
// D is the depth of the filter, H is the height of the filter, and W is the
// width of the filter.
std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
if (groups > 1) {
// M /= groups
......@@ -219,13 +224,15 @@ class ScopedConvolutionDescriptor {
PADDLE_ENFORCE_EQ(pads.size(), strides.size());
PADDLE_ENFORCE_EQ(pads.size(), dilations.size());
#if CUDNN_VERSION < 6000
#if !CUDNN_VERSION_MIN(6, 0, 0)
// cudnn v5 does not support dilation conv, the argument is called upscale
// instead of dilations and it is must be one.
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_EQ(
dilations[i], 1,
"Dilations conv is not supported in this cuDNN version");
"Dilations conv is not supported in this cuDNN version(%d.%d.%d).",
CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100,
CUDNN_VERSION % 100);
}
#endif
......
......@@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) {
EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor tensor5d_desc;
std::vector<int> shape_5d = {2, 4, 6, 6, 6};
auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);
std::vector<int> dims_5d(5);
std::vector<int> strides_5d(5);
paddle::platform::dynload::cudnnGetTensorNdDescriptor(
desc_5d, 5, &type, &nd, dims_5d.data(), strides_5d.data());
EXPECT_EQ(nd, 5);
for (size_t i = 0; i < dims_5d.size(); ++i) {
EXPECT_EQ(dims_5d[i], shape_5d[i]);
}
EXPECT_EQ(strides_5d[4], 1);
EXPECT_EQ(strides_5d[3], 6);
EXPECT_EQ(strides_5d[2], 36);
EXPECT_EQ(strides_5d[1], 216);
EXPECT_EQ(strides_5d[0], 864);
}
TEST(CudnnHelper, ScopedFilterDescriptor) {
......@@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) {
for (size_t i = 0; i < shape.size(); ++i) {
EXPECT_EQ(kernel[i], shape[i]);
}
ScopedFilterDescriptor filter_desc_4d;
std::vector<int> shape_4d = {2, 3, 3, 3};
auto desc_4d = filter_desc.descriptor<float>(DataLayout::kNCDHW, shape_4d);
std::vector<int> kernel_4d(4);
paddle::platform::dynload::cudnnGetFilterNdDescriptor(
desc_4d, 4, &type, &format, &nd, kernel_4d.data());
EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format);
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < shape_4d.size(); ++i) {
EXPECT_EQ(kernel_4d[i], shape_4d[i]);
}
}
TEST(CudnnHelper, ScopedConvolutionDescriptor) {
......
......@@ -1826,7 +1826,7 @@ class FCLayer(LayerBase):
self.layer_type = 'mkldnn_fc'
config_assert(
len(inputs) == 1,
"MkldnnFCLayer support one and only one input!")
"MKLDNNFCLayer support one and only one input!")
super(FCLayer, self).__init__(
name, self.layer_type, size, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
......@@ -1837,7 +1837,7 @@ class FCLayer(LayerBase):
sparse = format == "csr" or format == "csc"
if use_mkldnn:
config_assert(not sparse,
"MkldnnFCLayer do not support sparse format yet")
"MKLDNNFCLayer do not support sparse format yet")
if use_mkldnn_wgt:
dims = [self.config.size, input_layer.size]
if sparse:
......@@ -1853,7 +1853,7 @@ class FCLayer(LayerBase):
@config_layer('mkldnn_fc')
class MkldnnFcLayer(FCLayer):
class MKLDNNFcLayer(FCLayer):
layer_type = 'mkldnn_fc'
......@@ -3209,6 +3209,18 @@ class SubNestedSequenceLayer(LayerBase):
self.set_layer_size(size)
@config_layer('dot_prod')
class DotProdLayer(LayerBase):
def __init__(self, name, inputs, device=None):
super(DotProdLayer, self).__init__(
name, 'dot_prod', 0, inputs, device=device)
config_assert(len(inputs) == 2, 'DotProdLayer must have 2 inputs.')
config_assert(
self.get_input_layer(0).size == self.get_input_layer(1).size,
"Two inputs should have the same size.")
self.set_layer_size(1)
@config_layer('out_prod')
class OuterProdLayer(LayerBase):
def __init__(self, name, inputs, device=None):
......@@ -3330,6 +3342,20 @@ class RowL2NormLayer(LayerBase):
self.set_layer_size(input_layer.size)
@config_layer('cos')
class CosSimLayer(LayerBase):
def __init__(self, name, inputs, cos_scale=1, device=None):
super(CosSimLayer, self).__init__(
name, 'cos', 1, inputs=inputs, device=device)
config_assert(
len(self.inputs) == 2,
'The CosSimLayer expects two and only two inputs.')
config_assert(
self.get_input_layer(0).size == self.get_input_layer(1).size,
'The two inputs of CosSimLayer must have the same dimensionality.')
self.config.cos_scale = cos_scale
@config_layer('cos_vm')
class CosSimVecMatLayer(LayerBase):
def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
......@@ -3337,10 +3363,24 @@ class CosSimVecMatLayer(LayerBase):
name, 'cos_vm', size, inputs=inputs, device=device)
self.config.cos_scale = cos_scale
config_assert(
len(self.inputs) == 2, 'CosSimVecMatLayer must have 2 inputs')
len(self.inputs) == 2, 'The CosSimVecMatLayer must have 2 inputs.')
config_assert(
size * self.get_input_layer(0).size == self.get_input_layer(1).size,
'Wrong input size for CosSimVecMatLayer')
'Wrong input size for CosSimVecMatLayer.')
@config_layer('l2_distance')
class L2DistanceLayer(LayerBase):
def __init__(self, name, inputs, device=None):
super(L2DistanceLayer, self).__init__(
name, 'l2_distance', 1, inputs=inputs, device=device)
config_assert(
len(self.inputs) == 2, ('The L2DistanceLayer must have '
'and only have 2 inputs.'))
config_assert(
self.get_input_layer(0).size == self.get_input_layer(1).size,
('Two inputs of the L2DistanceLayer must have '
'the same dimensionality.'))
@config_layer('sampling_id')
......@@ -3384,18 +3424,6 @@ class AverageLayer(LayerBase):
self.create_bias_parameter(bias, self.config.size)
@config_layer('cos')
class CosSimLayer(LayerBase):
def __init__(self, name, inputs, cos_scale=1, device=None):
super(CosSimLayer, self).__init__(
name, 'cos', 1, inputs=inputs, device=device)
config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
config_assert(
self.get_input_layer(0).size == self.get_input_layer(1).size,
'inputs of CosSimLayer must have same dim')
self.config.cos_scale = cos_scale
@config_layer('tensor')
class TensorLayer(LayerBase):
def __init__(self, name, size, inputs, bias=True, **xargs):
......@@ -3506,11 +3534,17 @@ def ExpressionLayer(name, inputs, **xargs):
@config_layer('concat')
class ConcatenateLayer(LayerBase):
layer_type = 'concat'
def __init__(self, name, inputs, bias=False, **xargs):
config_assert(inputs, 'inputs cannot be empty')
config_assert(not bias, 'ConcatenateLayer cannot support bias.')
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
if self.layer_type == "mkldnn_concat":
config_assert(use_mkldnn, "mkldnn_concat only support MKLDNN")
self.layer_type = 'mkldnn_concat' if use_mkldnn else 'concat'
super(ConcatenateLayer, self).__init__(
name, 'concat', 0, inputs=inputs, **xargs)
name, self.layer_type, 0, inputs=inputs, **xargs)
size = 0
for input_index in xrange(len(self.inputs)):
assert self.get_input_layer(0).height == self.get_input_layer(
......@@ -3530,6 +3564,11 @@ class ConcatenateLayer(LayerBase):
self.set_layer_size(size)
@config_layer('mkldnn_concat')
class MKLDNNConcatLayer(ConcatenateLayer):
layer_type = 'mkldnn_concat'
# like concat layer, but each input layer was processed by a Projection.
@config_layer('concat2')
class ConcatenateLayer2(LayerBase):
......
......@@ -51,6 +51,7 @@ __all__ = [
'last_seq',
'first_seq',
'cos_sim',
'l2_distance_layer',
'hsigmoid',
'conv_projection',
'square_error_cost',
......@@ -115,6 +116,7 @@ __all__ = [
'huber_classification_cost',
'block_expand_layer',
'maxout_layer',
'dot_prod_layer',
'out_prod_layer',
'printer_layer',
'print_layer',
......@@ -167,6 +169,7 @@ class LayerType(object):
COST = 'cost'
COSINE_SIM_VEC = 'cos_vm'
COSINE_SIM = 'cos'
L2_DISTANCE = 'l2_distance'
HSIGMOID = 'hsigmoid'
CONV_LAYER = 'conv'
CONVTRANS_LAYER = 'convt'
......@@ -197,6 +200,7 @@ class LayerType(object):
SCALING_LAYER = 'scaling'
TRANS_LAYER = 'trans'
ROTATE_LAYER = 'rotate'
DOT_PROD_LAYER = 'dot_prod'
OUT_PROD_LAYER = 'out_prod'
FEATURE_MAP_EXPAND_LAYER = 'featmap_expand'
......@@ -2332,6 +2336,51 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b], size=size)
@wrap_name_default()
@layer_support()
def l2_distance_layer(x, y, name=None, layer_attr=None):
"""
This layer calculates and returns the Euclidean distance between two input
vectors x and y. The equation is as follows:
.. math::
l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)}
The output size of this layer is fixed to be 1. Note that the above
computation is for one sample. Multiple samples are processed in one batch.
The example usage is:
.. code-block:: python
l2_sim = l2_distance(x=layer1, y=layer2)
:param name: The name of this layer. It is optional.
:type name: basestring
:param x: The first input x for this layer, whose output is a matrix with
dimensionality N x D. N is the sample number in a mini-batch.
D is the dimensionality of x's output.
:type x: LayerOutput
:param y: The second input y for this layer, whose output is a matrix with
dimensionality N x D. N is the sample number in a mini-batch.
D is the dimensionality of y's output.
:type y: LayerOutput
:param layer_attr: The extra layer attributes, for example, drop rate.
See ExtraLayerAttribute for more details.
:type layer_attr: ExtraLayerAttribute
:return: The returned LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(x, LayerOutput) and isinstance(y, LayerOutput)
Layer(
name=name,
type=LayerType.L2_DISTANCE,
inputs=[x.name, y.name],
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1)
@wrap_name_default()
@wrap_bias_attr_default(has_bias=True)
@wrap_param_attr_default()
......@@ -3871,7 +3920,7 @@ def recurrent_layer(input,
:type input: LayerOutput
:param act: Activation type. TanhActivation is the default activation.
:type act: BaseActivation
:param bias_attr: The parameter attribute for bias. If this parameter is set to
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If the parameter is set to True,
the bias is initialized to zero.
......@@ -4140,6 +4189,45 @@ def maxid_layer(input, name=None, layer_attr=None):
size=l.config.size)
@wrap_name_default()
def dot_prod_layer(input1, input2, name=None, layer_attr=None):
"""
A layer for computing the dot product of two vectors.
The example usage is:
.. code-block:: python
dot_prod = dot_prod_layer(input1=vec1, input2=vec2)
:param name: The name of this layer. It is optional.
:type name: basestring
:param input1: The first input layer.
:type input: LayerOutput
:param input2: The second input layer.
:type input2: LayerOutput
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input1, LayerOutput)
assert isinstance(input2, LayerOutput)
assert input1.size == input2.size, ("Two inputs should have the same size.")
l = Layer(
name=name,
type=LayerType.DOT_PROD_LAYER,
inputs=[input1.name, input2.name],
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name=name,
layer_type=LayerType.DOT_PROD_LAYER,
parents=[input1, input2],
size=l.config.size)
@wrap_name_default()
def out_prod_layer(input1, input2, name=None, layer_attr=None):
"""
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
IdentityActivation, TanhActivation, SequenceSoftmaxActivation
......@@ -26,9 +26,9 @@ __all__ = [
'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
"img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
'simple_attention', 'dot_product_attention', 'simple_gru2',
'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
'outputs'
'simple_attention', 'dot_product_attention', 'multi_head_attention',
'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
'inputs', 'outputs'
]
######################################################
......@@ -1476,10 +1476,8 @@ def dot_product_attention(encoded_sequence,
expand_as=encoded_sequence,
name='%s_expand' % name)
m = linear_comb_layer(
weights=expanded,
vectors=encoded_sequence,
name='%s_dot-product' % name)
m = dot_prod_layer(
input1=expanded, input2=encoded_sequence, name='%s_dot-product' % name)
attention_weight = fc_layer(
input=m,
......@@ -1498,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
@wrap_name_default()
def multi_head_attention(query,
key,
value,
key_proj_size,
value_proj_size,
head_num,
attention_type,
softmax_param_attr=None,
name=None):
"""
Calculate and return a context vector with dot-product attention mechanism.
The dimension of the context vector equals to value_proj_size * head_num.
Please refer to **Attention Is All You Need** for more details. The link is
as follows:
https://arxiv.org/abs/1706.03762.
The example usage is:
.. code-block:: python
context = multi_head_attention(query=decoder_state,
key=enc_seq,
value=enc_seq,
key_proj_size=64,
value_pro_size=64,
head_num=8,
attention_type='dot-product attention')
:param name: A prefix attached to the name of each layer that defined inside
the multi_head_attention.
:type name: basestring
:param softmax_param_attr: The parameter attribute of sequence softmax
that is used to produce attention weight.
:type softmax_param_attr: ParameterAttribute
:param query: query is used to calculate attention weights over values at current step.
:type query: LayerOutput
:param key: key is used to calculate the attention weight of the corresponding value.
:type key: LayerOutput
:param value: value is the sequence to be attended.
:type value: LayerOutput
:param key_proj_size: The dimension of the linear projection performed on key and query.
:type key_proj_size: int
:param value_proj_size: The dimension of the linear projection performed on value.
:type value_proj_size: int
:param head_num: The number of attention heads.
:type head_num: int
:param attention_type: The type of the attention mechanism used in each attention
heads. Now, we only support scaled dot-product attention and
additive attention.
:type attention_type: basestring
:return: The context vector.
:rtype: LayerOutput
"""
assert attention_type in ['dot-product attention', 'additive attention']
with mixed_layer(
size=key_proj_size * head_num,
name='%s_query_proj' % name) as query_proj:
query_proj += full_matrix_projection(query)
query_proj = expand_layer(input=query_proj, expand_as=key)
with mixed_layer(
size=key_proj_size * head_num,
name='%s_key_proj' % name) as key_proj:
key_proj += full_matrix_projection(key)
with mixed_layer(
size=value_proj_size * head_num,
name='%s_value_proj' % name) as value_proj:
value_proj += full_matrix_projection(value)
head_list = []
for i in range(head_num):
with mixed_layer(size=key_proj_size) as sub_query_proj:
sub_query_proj += identity_projection(
query_proj, offset=key_proj_size * i, size=key_proj_size)
with mixed_layer(size=key_proj_size) as sub_key_proj:
sub_key_proj += identity_projection(
key_proj, offset=key_proj_size * i, size=key_proj_size)
with mixed_layer(size=value_proj_size) as sub_value_proj:
sub_value_proj += identity_projection(
value_proj, offset=value_proj_size * i, size=value_proj_size)
if attention_type == 'dot-product attention':
m = dot_prod_layer(
input1=sub_query_proj,
input2=sub_key_proj,
name='%s_dot-product_%d' % (name, i))
m = slope_intercept_layer(
input=m,
slope=math.sqrt(1.0 / key_proj_size),
name='%s_dot-product_scaling_%d' % (name, i))
else:
with mixed_layer(
size=key_proj_size,
act=TanhActivation(),
name='%s_combine_%d' % (name, i)) as m:
m += identity_projection(sub_query_proj)
m += identity_projection(sub_key_proj)
attention_weight = fc_layer(
input=m,
size=1,
act=SequenceSoftmaxActivation(),
param_attr=softmax_param_attr,
name="%s_softmax_%d" % (name, i),
bias_attr=False)
scaled = scaling_layer(
weight=attention_weight,
input=sub_value_proj,
name='%s_scaling_%d' % (name, i))
head = pooling_layer(
input=scaled,
pooling_type=SumPooling(),
name="%s_pooling_%d" % (name, i))
head_list.append(head)
attended = concat_layer(head_list)
return attended
def inputs(layers, *args):
"""
Declare the inputs of network. The order of input should be as same as
......
......@@ -10,6 +10,7 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer)
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "vector1"
type: "data"
size: 10
active_type: ""
}
layers {
name: "vector2"
type: "data"
size: 10
active_type: ""
}
layers {
name: "__dot_prod_layer_0__"
type: "dot_prod"
size: 1
active_type: ""
inputs {
input_layer_name: "vector1"
}
inputs {
input_layer_name: "vector2"
}
}
input_layer_names: "vector1"
input_layer_names: "vector2"
output_layer_names: "__dot_prod_layer_0__"
sub_models {
name: "root"
layer_names: "vector1"
layer_names: "vector2"
layer_names: "__dot_prod_layer_0__"
input_layer_names: "vector1"
input_layer_names: "vector2"
output_layer_names: "__dot_prod_layer_0__"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "x"
type: "data"
size: 128
active_type: ""
}
layers {
name: "y"
type: "data"
size: 128
active_type: ""
}
layers {
name: "__l2_distance_layer_0__"
type: "l2_distance"
size: 1
active_type: ""
inputs {
input_layer_name: "x"
}
inputs {
input_layer_name: "y"
}
}
input_layer_names: "x"
input_layer_names: "y"
output_layer_names: "__l2_distance_layer_0__"
sub_models {
name: "root"
layer_names: "x"
layer_names: "y"
layer_names: "__l2_distance_layer_0__"
input_layer_names: "x"
input_layer_names: "y"
output_layer_names: "__l2_distance_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
vec1 = data_layer(name='vector1', size=10)
vec2 = data_layer(name='vector2', size=10)
dot_product = dot_prod_layer(input1=vec1, input2=vec2)
outputs(dot_product)
from paddle.trainer_config_helpers import *
outputs(
l2_distance_layer(
x=data_layer(
name='x', size=128), y=data_layer(
name='y', size=128)))
......@@ -4,7 +4,10 @@ import collections
import numpy as np
import copy
__all__ = ['Block', 'Variable', 'Program', 'Operator', 'default_startup_program', 'default_main_program']
__all__ = [
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
'default_main_program'
]
def unique_name(prefix):
......@@ -232,17 +235,17 @@ class Operator(object):
in_proto.name)
if found:
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
if not in_proto.duplicable and len(in_argus) > 1:
in_args = inputs[in_proto.name]
if not isinstance(in_args, list):
in_args = [in_args]
if not in_proto.duplicable and len(in_args) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given."
% (in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name)
self.desc.set_input(in_proto.name, in_argu_names)
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
in_arg_names.append(arg.name)
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......@@ -260,18 +263,18 @@ class Operator(object):
str(e) for e in given)))
for out_proto in proto.outputs:
out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
out_argus = [out_argus]
if not out_proto.duplicable and len(out_argus) > 1:
out_args = outputs[out_proto.name]
if not isinstance(out_args, list):
out_args = [out_args]
if not out_proto.duplicable and len(out_args) > 1:
raise ValueError(
"Output %s expects only one output, but %d are given." %
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name)
argu.op = self
self.desc.set_output(out_proto.name, out_argu_names)
(out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
out_arg_names.append(arg.name)
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
if attrs is not None:
if not isinstance(attrs, dict):
......@@ -582,8 +585,10 @@ class Parameter(Variable):
g_main_program = Program()
g_startup_program = Program()
def default_startup_program():
return g_startup_program
def default_main_program():
return g_main_program
......@@ -108,5 +108,11 @@ class TestWithStride(TestConv3dTransposeOp):
self.filter_size = [f_c, 6, 3, 3, 3]
# ------------ test_cudnn ------------
class TestCudnn(TestConv3dTransposeOp):
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
if __name__ == '__main__':
unittest.main()
......@@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest):
batch_size = 9
lod = [[0, 2, 6, 9]]
batch_size = lod[0][-1]
frame_size = 5
activate = {
'identity': identity,
......@@ -35,7 +36,7 @@ class TestGRUOp(OpTest):
seq_starts[sorted_seqs[i]] + batch_idx)
idx_in_seq.append(idx)
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list
return idx_in_seq_list, sorted_seqs
def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0]
......@@ -66,8 +67,8 @@ class TestGRUOp(OpTest):
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
......@@ -84,8 +85,9 @@ class TestGRUOp(OpTest):
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = [[0, 2, 6, self.batch_size]]
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
lod = self.lod
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
......@@ -146,7 +148,7 @@ class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'identity',
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
......
import op_test
import unittest
import numpy as np
def create_test_class(op_type, callback, binary_op=True):
class Cls(op_test.OpTest):
def setUp(self):
a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
if binary_op:
b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
c = callback(a, b)
else:
c = callback(a)
self.outputs = {'Out': c}
self.op_type = op_type
if binary_op:
self.inputs = {'X': a, 'Y': b}
else:
self.inputs = {'X': a}
def test_output(self):
self.check_output()
Cls.__name__ = op_type
globals()[op_type] = Cls
create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b))
if __name__ == '__main__':
unittest.main()
......@@ -16,14 +16,18 @@ class TestOptimizer(unittest.TestCase):
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
opts = sgd_optimizer.minimize(mul_out, init_program)
opts = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
......@@ -44,12 +48,16 @@ class TestOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
global_step = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="step")
learning_rate = 0.01
sgd_optimizer = optimizer.SGDOptimizer(
learning_rate=learning_rate, global_step=global_step)
opts = sgd_optimizer.minimize(mul_out, init_program)
opts = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 2)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
......@@ -90,7 +98,11 @@ class TestMomentumOptimizer(unittest.TestCase):
learning_rate = 0.01
momentum_optimizer = self.MockMomentum(
learning_rate=learning_rate, momentum=0.2)
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
......@@ -132,10 +144,14 @@ class TestMomentumOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
momentum_optimizer = self.MockMomentum(
learning_rate=learning_rate, momentum=0.2, use_nesterov=True)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
......@@ -186,10 +202,14 @@ class TestAdagradOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adagrad_optimizer = self.MockAdagrad(
learning_rate=learning_rate, epsilon=1.0e-6)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -242,10 +262,14 @@ class TestAdamOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adam_optimizer = self.MockAdam(
learning_rate=learning_rate, beta1=0.9, beta2=0.999)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -300,10 +324,14 @@ class TestAdamaxOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adamax_optimizer = self.MockAdamax(
learning_rate=learning_rate, beta1=0.9, beta2=0.999)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -355,10 +383,14 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
decayed_adagrad_optimizer = self.MockDecayedAdagrad(
learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
opts = decayed_adagrad_optimizer.create_optimization_pass(
......
......@@ -3,8 +3,7 @@ import numpy as np
from op_test import OpTest
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
......@@ -23,8 +22,7 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
return out
def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def avg_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
......@@ -47,6 +45,7 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
class TestPool2d_Op(OpTest):
def setUp(self):
self.init_test_case()
self.init_global_pool()
self.init_op_type()
self.init_pool_type()
if self.global_pool:
......@@ -75,8 +74,6 @@ class TestPool2d_Op(OpTest):
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def init_test_case(self):
self.global_pool = True
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
......@@ -87,12 +84,14 @@ class TestPool2d_Op(OpTest):
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = True
class TestCase1(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
......@@ -103,12 +102,14 @@ class TestCase1(TestPool2d_Op):
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase2(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
......@@ -119,152 +120,69 @@ class TestCase2(TestPool2d_Op):
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase3(TestPool2d_Op):
def init_test_case(self):
self.global_pool = True
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase3(TestPool2d_Op):
def init_op_type(self):
self.op_type = "pool2d"
def init_pool_type(self):
self.pool_type = "max"
class TestCase4(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase4(TestCase1):
def init_op_type(self):
self.op_type = "pool2d"
def init_pool_type(self):
self.pool_type = "max"
class TestCase5(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase5(TestCase2):
def init_op_type(self):
self.op_type = "pool2d"
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
#--------------------test pool2d_cudnn--------------------
class TestCaseCudnn1(TestPool2d_Op):
def init_test_case(self):
self.global_pool = True
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCudnnCase1(TestPool2d_Op):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "avg"
class TestCaseCudnn2(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCudnnCase2(TestCase1):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "avg"
class TestCaseCudnn3(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCudnnCase3(TestCase2):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "avg"
class TestCaseCudnn4(TestPool2d_Op):
def init_test_case(self):
self.global_pool = True
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCudnnCase4(TestCase3):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "max"
class TestCaseCudnn5(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCudnnCase5(TestCase4):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "max"
class TestCaseCudnn6(TestPool2d_Op):
def init_test_case(self):
self.global_pool = False
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCudnnCase6(TestCase5):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
def init_pool_type(self):
self.pool_type = "max"
if __name__ == '__main__':
unittest.main()
......@@ -3,8 +3,7 @@ import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
......@@ -27,8 +26,7 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
return out
def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def avg_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
......@@ -55,6 +53,10 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
class TestPool3d_Op(OpTest):
def setUp(self):
self.init_test_case()
self.init_global_pool()
self.init_op_type()
self.init_pool_type()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32")
......@@ -81,74 +83,115 @@ class TestPool3d_Op(OpTest):
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def init_test_case(self):
self.global_pool = True
self.op_type = "pool3d"
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
def init_global_pool(self):
self.global_pool = True
class TestCase1(TestPool3d_Op):
def init_test_case(self):
self.global_pool = False
self.op_type = "pool3d"
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase2(TestPool3d_Op):
def init_test_case(self):
self.global_pool = False
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase2(TestPool3d_Op):
def init_test_case(self):
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
def init_global_pool(self):
self.global_pool = False
class TestCase3(TestPool3d_Op):
def init_test_case(self):
self.global_pool = True
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase4(TestPool3d_Op):
def init_test_case(self):
self.global_pool = False
class TestCase4(TestCase1):
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase5(TestPool3d_Op):
def init_test_case(self):
self.global_pool = False
class TestCase5(TestCase2):
def init_op_type(self):
self.op_type = "pool3d"
def init_pool_type(self):
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
#--------------------test pool3d_cudnn--------------------
class TestCudnnCase1(TestPool3d_Op):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
class TestCudnnCase2(TestCase1):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
class TestCudnnCase3(TestCase2):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
class TestCudnnCase4(TestCase3):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
class TestCudnnCase5(TestCase4):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
class TestCudnnCase6(TestCase5):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
if __name__ == '__main__':
......
import unittest
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.framework import g_main_program
......@@ -98,21 +97,26 @@ class TestProgram(unittest.TestCase):
"Y": add_y},
outputs={"Out": add_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": add_out}, outputs={"Out": mean_out})
self.assertEqual(mul_op.idx, 0)
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(add_out, set())
param_to_grad = prog.append_backward(mean_out, set())
def grad_name(name):
return name + "@GRAD"
for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out"):
for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out",
"mean.out"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)
expect_ops = [
"mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
"mul_grad"
"mul", "elementwise_add", "mean", "fill_constant", "mean_grad",
"elementwise_add_grad", "mul_grad"
]
actual_ops = []
for op in block.ops:
......
......@@ -29,7 +29,11 @@ class TestL2DecayRegularizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
......@@ -62,7 +66,11 @@ class TestL1DecayRegularizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
......
import unittest
import numpy as np
import sys
from op_test import OpTest
class TestSequenceSliceOp(OpTest):
def set_data(self):
self.init_test_case()
# only supprot one level LoD
x = np.random.random(self.x_dim).astype('float32')
lod = self.x_lod
offset = np.array(self.offset).astype("int64")
length = np.array(self.length).astype("int64")
self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
outs = [] #np.zeros((100, 3, 2)).astype('float32')
out_lod = [[0]]
out_lod_offset = 0
for i in range(len(offset)):
sub_x = x[lod[0][i] + offset[i, 0]: lod[0]
[i] + offset[i, 0] + length[i, 0], :]
out_lod_offset = out_lod_offset + len(sub_x)
outs.append(sub_x)
out_lod[0].append(out_lod_offset)
outs = np.concatenate(outs, axis=0)
self.outputs = {'Out': (outs, out_lod)}
def init_test_case(self):
self.x_dim = (100, 3, 2)
self.x_lod = [[0, 20, 40, 60, 80, 100]]
self.offset = [[1], [2], [3], [4], [5]]
self.length = [[10], [8], [6], [4], [2]]
def setUp(self):
self.op_type = "sequence_slice"
self.set_data()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册