提交 83dbc150 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/complete_variable_bind

......@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector
}
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
......
......@@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "A");
AddInput("Y", "B");
AddOutput("Out", "Out");
AddAttr<int>("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
AddAttr<int>("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
AddComment("Mul");
}
};
......@@ -440,6 +442,28 @@ TEST(Backward, simple_single_op) {
std::vector<std::string>({f::GradVarName("b")}));
}
TEST(Backward, default_attribute) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp();
op->SetType("mul");
op->SetInput("X", {"x"});
op->SetInput("Y", {"y"});
op->SetOutput("Out", {"out"});
AppendBackward(program, {});
ASSERT_EQ(block->AllOps().size(), 2UL);
EXPECT_EQ(boost::get<int>(op->GetAttr("x_num_col_dims")), 1);
EXPECT_EQ(boost::get<int>(op->GetAttr("y_num_col_dims")), 1);
f::OpDescBind *grad_op = block->AllOps()[1];
ASSERT_EQ(grad_op->Type(), "mul_grad");
EXPECT_EQ(boost::get<int>(grad_op->GetAttr("x_num_col_dims")), 1);
EXPECT_EQ(boost::get<int>(grad_op->GetAttr("y_num_col_dims")), 1);
}
TEST(Backward, simple_mult_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>
#include "paddle/framework/op_desc.h"
......
......@@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;
enum AttrType {
......
......@@ -25,6 +25,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}
OpDesc *OpDescBind::Proto() {
......
......@@ -52,8 +52,6 @@ class OpDescBind {
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
std::string DebugString() { return this->Proto()->DebugString(); }
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
......@@ -97,6 +95,11 @@ class OpDescBind {
const VariableNameMap &Outputs() const { return outputs_; }
AttributeMap *MutableAttrMap() {
this->need_update_ = true;
return &this->attrs_;
}
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
}
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type());
return info.grad_op_maker_(op_desc);
OpDescBind* op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc->Type());
if (info.Checker() != nullptr) {
info.Checker()->Check(*op_desc->MutableAttrMap());
}
return info.grad_op_maker_(*op_desc);
}
} // namespace framework
......
......@@ -80,7 +80,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc);
OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/macros.h"
......@@ -31,8 +32,6 @@ class ProgramDescBind {
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
std::string DebugString() { return Proto()->DebugString(); }
size_t Size() const { return blocks_.size(); }
ProgramDesc *Proto();
......
......@@ -95,6 +95,19 @@ class Tensor {
template <typename T>
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external vector.
* @param[in] ctx The device context contains place where to store.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src,
const platform::Place& dst_place);
/**
* @brief Return the slice of the tensor.
*
......
......@@ -123,6 +123,29 @@ inline void Tensor::CopyFrom(const Tensor& src,
#endif
}
template <typename T>
inline void Tensor::CopyFromVector(const std::vector<T>& src,
const platform::Place& dst_place) {
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = src.size() * sizeof(T);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size, 0);
}
PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in Tensor CopyFromVector");
#endif
}
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
......
......@@ -263,6 +263,93 @@ TEST(Tensor, CopyFrom) {
#endif
}
TEST(Tensor, CopyFromVector) {
using namespace paddle::framework;
using namespace paddle::platform;
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
const int* src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
}
delete cpu_place;
}
#ifdef PADDLE_WITH_CUDA
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor cpu_tensor;
Tensor gpu_tensor;
Tensor dst_tensor;
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace();
gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
// Copy from GPU to CPU tensor for comparison
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
// Compare Tensors
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
gpu_tensor.Resize(make_ddim({2, 2}));
gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
src_ptr = src_vec.data();
cpu_ptr = cpu_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, cpu_ptr);
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 5; ++i) {
EXPECT_EQ(src_ptr[i], cpu_ptr[i]);
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
delete cpu_place;
delete gpu_place;
}
#endif
}
TEST(Tensor, ReshapeToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
......
......@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <map>
#include <memory>
#include "paddle/platform/variant.h"
namespace paddle {
......
......@@ -162,4 +162,4 @@ int main(int argc, char** argv) {
return RUN_ALL_TESTS();
}
#endif /* PADDLE_ONLY_CPU */
#endif
......@@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
max_chunk_size_ = platform::GpuMaxChunkSize();
}
}
#endif // PADDLE_ONLY_CPU
#endif
// Allocate a new maximum sized block
size_t index = 0;
......
......@@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
bool GPUAllocator::UseGpu() const { return true; }
#endif // PADDLE_ONLY_CPU
#endif
} // namespace detail
} // namespace memory
......
......@@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator {
size_t gpu_alloc_size_ = 0;
size_t fallback_alloc_size_ = 0;
};
#endif // PADDLE_ONLY_CPU
#endif
} // namespace detail
} // namespace memory
......
......@@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) {
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
#endif
......@@ -89,7 +89,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
}
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -53,7 +53,7 @@ template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -111,7 +111,7 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) {
}
}
#endif // PADDLE_ONLY_CPU
#endif
......@@ -55,12 +55,20 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
......
......@@ -201,6 +201,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename AttrType>
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input of ELU operator, it shouldn't be empty. Input "
"is flattened and treated as a 1D array.");
AddOutput("Y",
"(Tensor) The output of ELU operator. It has the same shape as "
"the input.");
AddAttr<AttrType>(
"alpha", "(float, default 1.0) Alpha value in the elu formulation.")
.SetDefault(static_cast<AttrType>(1.));
AddComment(R"DOC(
ELU activation operator. It applies this element-wise computation on
the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)).
Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC");
}
};
template <typename AttrType>
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
......@@ -289,6 +310,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker<float>, elu_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
ops::ActivationOpGrad);
......
......@@ -384,6 +384,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
x.cwiseMax(static_cast<T>(0)) +
(alpha * (x.exp() - static_cast<T>(1))).cwiseMin(static_cast<T>(0));
}
};
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + alpha) * (x < static_cast<T>(0)).template cast<T>();
}
};
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
......@@ -440,21 +469,22 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor)
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_shift_op.h"
#include "paddle/framework/eigen.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
class ConvShiftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The 1st dimension of Input(X) and Input(Y) should "
"be equal.");
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be odd.");
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X).");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ConvShiftGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(x_grad_name, x_dims);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
auto y_dims = ctx->GetInputDim("Y");
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension.");
AddInput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x N, "
"where B is the batch size and N is the data dimension. N must "
"be odd.");
AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"i.e., the same shape as X.");
AddComment(R"DOC(
ConvShift Operator.
A layer for circular convolution of two vectors,
as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401
The equation is:
\f[
Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j}
\f]
where X's index is computed modulo M, and b's index is computed modulo N.
Both of the input `X` and `Y` can carry LoD (Level of Details) information.
However, the output only shares the LoD information with input `X`.
)DOC");
}
};
template <typename T>
class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<Tensor>("X");
auto *Y = context.Input<Tensor>("Y");
auto *Out = context.Output<Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = EigenMatrix<T>::From(*X);
auto y = EigenMatrix<T>::From(*Y);
auto out = EigenMatrix<T>::From(*Out);
out.setZero();
size_t batch_size = X->dims()[0];
size_t x_width = X->dims()[1];
size_t y_width = Y->dims()[1];
size_t y_half_width = (y_width - 1) / 2;
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
out(k, i) += x(k, index) * y(k, j);
}
}
}
}
};
template <typename T>
class ConvShiftGradKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<Tensor>("X");
auto *Y = context.Input<Tensor>("Y");
auto *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto *dY = context.Output<Tensor>(framework::GradVarName("Y"));
auto x = EigenMatrix<T>::From(*X);
auto y = EigenMatrix<T>::From(*Y);
auto dout = EigenMatrix<T>::From(*dOut);
auto x_dims = X->dims();
auto y_dims = Y->dims();
size_t batch_size = x_dims[0];
size_t x_width = x_dims[1];
size_t y_width = y_dims[1];
size_t y_half_width = (y_width - 1) / 2;
// The below trades code duplication for efficiency (keeping the if
// statement outside of the loop).
if (dX) {
dX->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::From(*dX);
dx.setZero();
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
dx(k, index) += dout(k, i) * y(k, j);
}
}
}
}
if (dY) {
dY->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::From(*dY);
dy.setZero();
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
dy(k, j) += x(k, index) * dout(k, i);
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
conv_shift_grad, ops::ConvShiftGradOp);
REGISTER_OP_CPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_shift_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using framework::Tensor;
namespace {
inline int div_up(int x, int y) { return (x + y - 1) / y; }
// Some notes on the design:
//
// Each thread is responsible for computing a single output out[k, i].
// Thread blocks are based on tiles of x with height 1 in the batch dimension.
//
// This design is based on the typical use case where the filter
// y is fairly small. For large y, it would probably be more efficient
// to also tile across y.
template <typename T>
__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
int y_width, int y_half_width,
int batch_size) {
extern __shared__ T mem[];
int tx = threadIdx.x;
int i = blockIdx.x * blockDim.x + tx; // global x index
int k = blockIdx.y; // batch index
// Check if we are in a boundary block with fewer x's to process than
// blockDim.x.
int num_x =
(blockIdx.x == gridDim.x - 1) ? (x_width % blockDim.x) : blockDim.x;
T *sx = mem;
T *sx_pad = &mem[num_x];
T *sy = &mem[blockDim.x + y_width];
// Collaboratively load y[k, :] and length-y padding of x into shared memory.
int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width;
for (int j = tx; j < y_width; j += blockDim.x) {
sy[j] = y[k * y_width + j];
sx_pad[j] = x[k * x_width + (pad_start + j) % x_width];
}
// Load a cyclically shifted slice of x into shared memory.
if (tx < num_x) {
int load_i = (i - y_half_width + x_width) % x_width;
sx[tx] = x[k * x_width + load_i];
} else {
return;
}
__syncthreads();
// Compute dot product of sx[tx:tx + y_width] and sy.
T sum = 0;
for (int j = 0; j < y_width; ++j) {
sum += sx[tx + j] * sy[j];
}
// Save to out[k, i].
out[k * x_width + i] = sum;
}
// Compute x gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index
if (i < x_width) {
int index = (i + j - y_half_width + x_width) % x_width;
atomicAdd(&dx[k * x_width + index],
dout[k * x_width + i] * y[k * y_width + j]);
}
}
// Compute y gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width,
int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index
if (i < x_width) {
int index = (i + j - y_half_width + x_width) % x_width;
atomicAdd(&dy[k * y_width + j],
x[k * x_width + index] * dout[k * x_width + i]);
}
}
} // namespace
template <typename T>
class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Y = context.Input<Tensor>("Y");
Tensor *Out = context.Output<Tensor>("Out");
const T *x_data = X->data<T>();
const T *y_data = Y->data<T>();
T *out_data = Out->mutable_data<T>(context.GetPlace());
int batch_size = X->dims()[0];
int x_width = X->dims()[1];
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
dim3 grid_dim(num_x_blocks, batch_size);
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
}
};
template <typename T>
class ConvShiftGradKernel<platform::GPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Y = context.Input<Tensor>("Y");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
const T *x_data = X->data<T>();
const T *y_data = Y->data<T>();
const T *dout_data = dOut->data<T>();
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
Tensor *dY = context.Output<Tensor>(framework::GradVarName("Y"));
int batch_size = X->dims()[0];
int x_width = X->dims()[1];
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
dim3 grid_dim(num_x_blocks, y_width, batch_size);
if (dX) {
T *dx_data = dX->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>(
dout_data, y_data, dx_data, x_width, y_width, y_half_width,
batch_size);
}
if (dY) {
T *dy_data = dY->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>(
x_data, dout_data, dy_data, x_width, y_width, y_half_width,
batch_size);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ConvShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
template <typename Place, typename T>
class ConvShiftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_constant_op.h"
namespace paddle {
namespace operators {
class FillConstantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64);
ctx->SetOutputDim("Out", dims);
}
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("dataType"));
}
};
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillConstantOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dataType",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(Fill up a variable with specified constant value.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
ops::FillConstantOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillConstantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<T>("value");
auto out_eigen = framework::EigenVector<T>::Flatten(*out);
auto place = ctx.GetEigenDevice<Place>();
out_eigen.device(place) = out_eigen.constant(static_cast<T>(value));
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
class InterpOp : public NetOp {
public:
InterpOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
"Input(X) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName,
"Input(Y) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName,
"Input(W) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName,
"Output(SubOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName,
"Output(MulOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of InterpOp should not be null.");
// SubOut = X - Y
auto x = Input("X");
auto y = Input("Y");
auto sub_out = Output("SubOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {}));
// MulOut = SubOut * W = (X - Y) * W
auto w = Input("W");
auto mul_out = Output("MulOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}},
{{"axis", 0}}));
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
AppendOp(framework::OpRegistry::CreateOp("elementwise_add",
{{"X", {mul_out}}, {"Y", {y}}},
{{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
class InterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
"containing data samples, the first input of interp_op");
AddInput("Y",
"(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
"containing data samples, the second input of interp_op");
AddInput("W",
"(Tensor), 1-D Vector of shape [batch_size],"
"the interpolated values in the half-open interval [0.0, 1.0)");
AddOutput("SubOut",
"(Tensor), the intermediate subtraction outputs, saving X - Y.")
.AsIntermediate();
AddOutput("MulOut",
"(Tensor), the intermediate multiplication outputs,"
"saving the elementwise multiplication of (X - Y) and W.")
.AsIntermediate();
AddOutput("Out",
"(Tensor), the output of interp_op, same shape with X,"
"returns the first-dimensional piecewise linear interpolant "
"between X and Y");
AddComment(R"DOC(
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
Equation:
Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i])
= (X.row[i] - Y.row[i]) * W[i] + Y.row[i]
Example:
X = [[1,2],[3,4]],
Y = [[2,1],[4,3]],
W = [0.3, 0.4]
Then, Out = [[1.7,1.3],[3.6,3.4]]
where 1.7 = 1*0.3+2*(1-0.3),
1.3 = 2*0.3+1*(1-0.3),
3.6 = 3*0.4+4*(1-0.4),
3.4 = 4*0.4+3*(1-0.4)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker);
......@@ -18,6 +18,11 @@ namespace paddle {
namespace operators {
namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent height
* and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public:
......@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
};
template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
......@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public:
......@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
};
template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
......@@ -458,6 +488,253 @@ template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[h * input_width + w]) {
ele = input_data[h * input_width + w];
index = h * input_width + w;
}
}
}
output_data[ph * output_width + pw] = ele;
mask_data[ph * output_width + pw] = index;
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_height = input_grad.dims()[2];
const int input_width = input_grad.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
for (int pw = 0; pw < output_width; ++pw) {
const int output_idx = ph * output_width + pw;
const int input_idx = static_cast<int>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
if (ele < input_data[input_idx]) {
index = input_idx;
ele = input_data[input_idx];
}
}
}
}
output_data[output_idx] = ele;
mask_data[output_idx] = index;
}
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3];
const int output_width = output_grad.dims()[4];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
for (int ph = 0; ph < output_height; ++ph) {
for (int pw = 0; pw < output_width; ++pw) {
const int output_idx =
(pd * output_height + ph) * output_width + pw;
const int input_idx = static_cast<int>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
此差异已折叠。
......@@ -21,15 +21,27 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__ //
#define FLT_MAX \
__FLT_MAX__ // It might need to be placed in another file, but I'm still
// wondering where to put it.
/*
* \brief Extracting simple operations from pooling.
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
* operation.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <class T>
class MaxPool {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void finalize(T& y, const T& poo_size) {}
DEVICE inline void finalize(T& y, const T& pool_field) {}
};
template <class T>
......@@ -37,8 +49,9 @@ class AvgPool {
public:
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; }
};
template <class T>
class MaxPoolGrad {
public:
......@@ -57,6 +70,20 @@ class AvgPoolGrad {
}
};
/*
* \brief Getting pooling results, and calculating gradient.
*
* In pool2d, all tensors are in NCHW format. Where N is batch size, C is the
* number of channels, H and W is the height and width of feature.
* In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the
* number of channels, D, H and W is the depth, height and width of feature.
*
* In max pooling, it is possible that the pooling region has multiple maximum
* elements. In this case, we should compute the gradient of the first maximum
* element.
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
template <typename Place, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
......@@ -117,6 +144,51 @@ class MaxPool3dGradFunctor {
std::vector<int>& strides, std::vector<int>& paddings);
};
/*
* \brief Getting max pooling results and corresponding max index, and
* calculating gradient.
* In up-sampling-pooling, it is necessary to know max element index.
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
* NCDHW format.
*/
template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool2dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <set>
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace paddle {
namespace operators {
inline int OutputSizeMaxPool(int input_size, int filter_size, int padding,
int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mask"),
"Mask(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Intput size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
}
};
class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings",
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters. Input(X) and
output(Out, Mask) are in NCHW format. Where N is batch size, C is the
number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
)DOC");
}
};
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters.
Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch
size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
pool2d_forward;
pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
pool3d_forward;
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
}
}
};
template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
}
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
#endif // PADDLE_ONLY_CPU
#endif
} // namespace platform
} // namespace paddle
......@@ -41,7 +41,7 @@ limitations under the License. */
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif // PADDLE_ONLY_CPU
#endif
namespace paddle {
namespace platform {
......
......@@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
} // namespace platform
} // namespace paddle
#endif // PADDLE_ONLY_CPU
#endif
......@@ -117,7 +117,6 @@ void BindProgramDesc(py::module &m) {
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("__str__", &ProgramDescBind::DebugString)
.def("num_blocks", &ProgramDescBind::Size);
}
......@@ -193,8 +192,6 @@ void BindOpDesc(py::module &m) {
.def("output", &OpDescBind::Output)
.def("output_names", &OpDescBind::OutputNames)
.def("set_output", &OpDescBind::SetOutput)
.def("__str__", &OpDescBind::DebugString)
.def("__repr__", &OpDescBind::DebugString)
.def("has_attr", &OpDescBind::HasAttr)
.def("attr_type", &OpDescBind::GetAttrType)
.def("attr_names", &OpDescBind::AttrNames)
......
......@@ -181,6 +181,26 @@ class TestSoftRelu(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestELU(OpTest):
def setUp(self):
self.op_type = "elu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
alpha = 1.
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
self.inputs = {'X': x}
self.attrs = {'alpha': alpha}
self.outputs = {
'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestReciprocal(OpTest):
def setUp(self):
self.op_type = "reciprocal"
......
import unittest
import numpy as np
from op_test import OpTest
def conv_shift_forward(x, y):
out = np.zeros_like(x)
M = x.shape[1]
N = y.shape[1]
y_half_width = (N - 1) / 2
for i in xrange(M):
for j in xrange(N):
out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j]
return out
class TestConvShiftOp(OpTest):
def setUp(self):
self.op_type = "conv_shift"
batch_size = 4
x_dim = 17
y_dim = 3 # must be odd and <= x_dim
x = np.random.random((batch_size, x_dim)).astype("float32")
y = np.random.random((batch_size, y_dim)).astype("float32")
self.inputs = {'X': x, 'Y': y}
out = conv_shift_forward(x, y)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
def test_check_grad_ignore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestFillConstantOp1(OpTest):
def setUp(self):
'''Test fill_constant op with specified value
'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3.8}
self.outputs = {'Out': np.full((123, 92), 3.8)}
def test_check_output(self):
self.check_output()
class TestFillConstantOp2(OpTest):
def setUp(self):
'''Test fill_constant op with default value
'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92]}
self.outputs = {'Out': np.full((123, 92), 0.0)}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestInterpOp(OpTest):
def setUp(self):
self.op_type = "interp"
x = np.random.random((2, 3)).astype("float32")
y = np.random.random((2, 3)).astype("float32")
w = np.random.random(2).astype("float32")
sub_out = x - y
mul_out = sub_out * w.reshape(2, 1)
out = mul_out + y
self.inputs = {'X': x, 'Y': y, 'W': w}
self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
mask = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in xrange(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
for j in xrange(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :, :]
index = np.where(arr == np.max(arr))
sub_deep = index[0][0]
sub_row = index[1][0]
sub_col = index[2][0]
index = ((d_start + sub_deep) * H +
(h_start + sub_row)) * W + w_start + sub_col
mask[n, c, k, i, j] = index
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][0]
sub_col = index[1][0]
index = (r_start + sub_row) * W + c_start + sub_col
mask[n, c, i, j] = index
return out, mask
class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'globalPooling': self.global_pool,
}
self.inputs = {'X': input}
self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self):
self.global_pool = True
self.index = "max_pool3d_with_index"
self.op_type = "%s" % self.index
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
class TestCase4(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase5(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
class TestCase6(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase7(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
class TestCase8(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase9(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册