提交 2f122561 编写于 作者: Y Yibing Liu

Merge branch 'develop' of upstream into margin_rank_loss_op_dev

# Design Doc: ProgramDesc
The basic structure of a PaddlePaddle program is some nested blocks, as a C++ or Java program.
As described in [graph.md](./graph.md), the first five lines of the following PaddlePaddle program
```python
x = layer.data("images")
l = layer.data("label")
y = layer.fc(x)
cost = layer.mse(y, l)
optimize(cost)
train(cost, reader=mnist.train())
```
generates, or compiles, a PaddelPaddle program, which is represented by the following protobuf message:
```protobuf
message ProgramDesc {
repeated BlockDesc blocks = 1;
}
message BlockDesc {
required int32 parent = 1;
repeated VarDesc vars = 2;
repeated OpDesc ops = 3;
}
message OpDesc {
AttrDesc attrs = 1;
...
}
message AttrDesc {
required AttrType type = 1;
// index into ProgramDesc::blocks when type==BLOCK
optional int32 block = 2;
...
}
```
When each of the first five lines runs, related Python function, e.g., `layer.fc`, calls C++ InferShape functions. This InferShape function needs to access the properties of VarDesc's accessed by the current OpDesc. These VarDesc's might not be defined in the current block, but in some ancestor blocks. This requires that we can trace the parent of a block.
A nested block is often an attribute of an operator, most likely, an IfElseOp or a WhileOp. In above solution, all blocks are in `ProgramDesc::blocks`, this implicitly assigns a zero-based ID to each block -- the index of the block in `ProgramDesc::blocks`. So that `AttrDesc::block` could be an integer block ID.
With this design, the InferShape function should take the following parameters:
```c++
void InferShape(int current_block,
int current_operator,
ProgramDesc* program // might change VarDesc values.
) {
...
}
```
where
- `current_block` indices into `ProgramDesc::blocks`,
- `current_operator` indices into `BlockDesc::ops`.
......@@ -54,9 +54,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
AddOutput("Out", "(Tensor), 2D tensor of size (M x N)");
AddComment(R"DOC(
Two Element Mul Operator.
The equation is: Out = X * Y
......@@ -72,7 +72,7 @@ The equation is: Out = X * Y
构造函数里通过`AddInput`添加输入参数,通过`AddOutput`添加输出参数,通过`AddComment`添加Op的注释。这些函数会将对应内容添加到`OpProto`中。
上面的代码在`MulOp`中添加两个输入`X``Y`,添加了一个输出`Out`,并解释了各自含义,命名请遵守命名规范
上面的代码在`MulOp`中添加两个输入`X``Y`,添加了一个输出`Out`,并解释了各自含义,命名请遵守[命名规范](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md)
再以[`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37)为例:
......
......@@ -19,12 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
......
......@@ -292,5 +292,13 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
return framework::make_ddim(strides);
}
} // namespace framework
} // namespace paddle
......@@ -121,6 +121,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
DDim flatten_to_1d(const DDim& src);
DDim stride(const DDim& ddim);
} // namespace framework
} // namespace paddle
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_proto_maker.h"
#include "gtest/gtest.h"
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
......
......@@ -228,43 +228,5 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
} // namespace framework
} // namespace paddle
......@@ -167,71 +167,6 @@ class NOP : public OperatorBase {
}
};
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
......
......@@ -264,37 +264,3 @@ TEST(Operator, Clone) {
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
......@@ -58,6 +58,8 @@ class Scope {
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
const Scope& parent() const { return *parent_; }
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
......
......@@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
if (dims_[0] == 1) {
return *this;
} else {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
}
inline Tensor& Tensor::Resize(const DDim& dims) {
......
......@@ -55,6 +55,13 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# pybind USE_NO_KERNEL_OP
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
......@@ -96,3 +103,4 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/activation_op.h"
namespace paddle {
namespace operators {
class ActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddComment("Sigmoid activation operator, sigmoid = 1 / (1 + exp(-x))");
}
};
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddComment("Exp activation operator, exp(x) = e^x");
}
};
class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddComment("Relu activation operator, relu(x) = max(x, 0)");
}
};
class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Y", "Output of Tanh operator");
AddComment(
"Tanh activation operator, tanh = (exp(x) - exp(-x)) / (exp(x) + "
"exp(-x))");
}
};
class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Y", "Output of Sqrt operator");
AddComment("Sqrt activation operator, sqrt(x) = x^(1/2)");
}
};
class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Y", "Output of Abs operator");
AddComment("Abs activation operator, abs(x) = |x|");
}
};
class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReciprocalOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
AddOutput("Y", "Output of Reciprocal operator");
AddComment("Reciprocal activation operator, reciprocal(x) = 1 / x");
}
};
class LogOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
AddOutput("Y", "Output of Log operator");
AddComment("Log activation operator, log(x) = natural logarithm of x");
}
};
class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
AddOutput("Y", "Output of Square operator");
AddComment("Square activation operator, square(x) = x^2");
}
};
template <typename AttrType>
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddComment("BRelu activation operator, brelu = max(min(x, t_min), t_max)");
AddAttr<AttrType>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<AttrType>(0));
AddAttr<AttrType>("t_max", "The max marginal value of BRelu")
.SetDefault(static_cast<AttrType>(24));
}
};
template <typename AttrType>
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddComment(
"SoftRelu activation operator, soft_relu = log(1 + exp(max(min(x, "
"threshold), threshold)))");
AddAttr<AttrType>("threshold", "The threshold value of SoftRelu")
.SetDefault(static_cast<AttrType>(40));
}
};
template <typename AttrType>
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddComment("Pow activation operator, pow(x, factor) = x^factor");
AddAttr<AttrType>("factor", "The exponential factor of Pow")
.SetDefault(static_cast<AttrType>(1));
}
};
template <typename AttrType>
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddComment("STanh activation operator, stanh = b * tanh(a * x)");
AddAttr<AttrType>("scale_a", "The scale parameter of a for the input")
.SetDefault(static_cast<AttrType>(2 / 3));
AddAttr<AttrType>("scale_b", "The scale parameter of b for the input")
.SetDefault(static_cast<AttrType>(1.7159));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_CPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_CPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_CPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::LogFunctor>);
REGISTER_OP_CPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP_GPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y);
}
};
template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y, dy, dx);
}
};
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * y * (static_cast<T>(1) - y);
}
};
// exp(x) = e^x
struct ExpFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.exp();
}
};
struct ExpGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * y;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (static_cast<T>(1) - y * y);
}
};
// sqrt(x) = x^(1/2)
struct SqrtFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
}
};
// abs(x) = |x|
struct AbsFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.abs();
}
};
struct AbsGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * x.sign();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * static_cast<T>(-1) * y * y;
}
};
// log(x) = natural logarithm of x
struct LogFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.log();
}
};
template <typename T>
struct LogGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (static_cast<T>(1) / x);
}
};
// square(x) = x^2
struct SquareFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * static_cast<T>(2) * x;
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max);
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval();
y.device(place) = (static_cast<T>(1) + temp.exp()).log();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
dx.device(place) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
}
};
template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.pow(factor);
}
};
template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * factor * x.pow(factor - static_cast<T>(1));
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = scale_b * (scale_a * x).tanh();
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(place) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,46 +12,65 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/clip_op.h"
namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
using framework::LoDTensor;
class ClipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of SigmoidOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) of SigmoidOp should not be null.");
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<Tensor>("X")->dims());
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of ClipOp should not be null.");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto max = Attr<float>("max");
auto min = Attr<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
ctx.Output<LoDTensor>("Out")->Resize(x_dims);
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function");
AddInput("X",
"(Tensor)The input of clip op."
"The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out", "(Tensor)The output of clip op with shape as input(X)");
AddAttr<AttrType>(
"min", "(float)Minimum value, under which element is replaced by min.");
AddAttr<AttrType>(
"max", "(float)Maximum value, above which element is replaced by max");
AddComment(R"DOC(
Clip operator limits the given input within an interval. The interval is
specified with arguments 'min' and 'max'.
)DOC");
}
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
class ClipOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
}
}
};
......@@ -59,9 +78,9 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,48 +13,83 @@
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
using framework::Tensor;
using platform::Transform;
template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
template <typename T>
class ClipFunctor {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const {
if (x < min_)
return min_;
else if (x > max_)
return max_;
else
return x;
}
// The clipping is used in Paddle's raw implenmention
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
private:
T min_;
T max_;
};
Y.device(place) = 1. / (1. + (-X).exp());
template <typename T>
class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0;
}
private:
T min_;
T max_;
};
template <typename Place, typename T>
class SigmoidGradKernel : public framework::OpKernel {
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
dX_t->mutable_data<T>(context.GetPlace());
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
}
};
auto dX = EigenVector<T>::Flatten(*dX_t);
auto Y = EigenVector<T>::Flatten(*Y_t);
auto dY = EigenVector<T>::Flatten(*dY_t);
dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<Tensor>("X");
int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform<Place> trans;
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
}
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gemm_conv2d_op.h"
namespace paddle {
namespace operators {
int outputSize(int input_size, int filter_size, int padding, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
class Conv2DOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
"Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
"Input(Filter) of Conv2DOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
"Output(Output) of Conv2DOp should not be null.");
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto out = ctx.Output<framework::LoDTensor>("Output");
std::vector<int> strides = Attr<std::vector<int>>("strides");
std::vector<int> paddings = Attr<std::vector<int>>("paddings");
int groups = Attr<int>("groups");
int input_channels = in->dims()[1];
int output_channels = filter->dims()[0];
PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
"Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
output_channels % groups, 0,
"The number of output channels should be divided by groups.");
auto output_height =
outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
auto output_width =
outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
out->Resize(
{in->dims()[0], filter->dims()[0], output_height, output_width});
}
};
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"The input tensor of convolution operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
AddInput(
"Filter",
"The filter tensor of convolution operator."
"The format of the filter tensor is MCHW, where M is the number of "
"output image channels, C is the number of input image channels, "
"H and W is height and width of filter. "
"If the groups attribute is greater than 1, C equal the number of "
"input image channels divided by the groups.");
AddOutput("Output",
"The output tensor of convolution operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
.SetDefault({0, 0});
AddAttr<int>(
"groups",
"group size of convolution operator. "
"Refer to grouped convolution in Alex Krizhevsky's paper: "
"when group=2, the first half of the filters are only connected to the "
"first half of the input channels, and the second half only connected "
"to the second half.")
.SetDefault(1);
AddComment(R"DOC(
The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
)DOC");
}
};
class Conv2DOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto d_in =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
auto d_filter =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
if (d_in) d_in->Resize(in->dims());
if (d_filter) d_filter->Resize(filter->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
ops::Conv2DOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gemm_conv2d_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/crop_op.h"
#include <boost/lexical_cast.hpp>
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class CropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx.Input<LoDTensor>("X")->dims();
auto *y = ctx.Input<LoDTensor>("Y");
auto *out = ctx.Output<LoDTensor>("Out");
if (y == nullptr) {
auto shape = Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
out->Resize(framework::make_ddim(tensor_shape));
} else {
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
"Tensor rank of both CropOp's "
"inputs must be same.");
out->Resize(y->dims());
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
AddInput("Y",
"The input used as reference for cropping"
" with the same dimension as X. ");
AddOutput("Out",
"The output of crop op "
"with the same dimension as X.");
AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped."
"The size of offsets list should be as same as "
"dimension size of input X.");
AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output."
"The size of shape list should be as same as "
"dimension size of input X.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
Crop Operator.
Crop input into output, as specified by offsets and shape.
There are two ways to set shape:
1. referenc input: crop input X as shape as reference input.
The dimension of reference input should
be as same as input X.
2. shape list: crop input X by shape described by a list<int>.
The size of shape list should be as same as
dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
Given:
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
and
offsets = [0, 1]
and
shape = [2, 2]
then we get
Out = [[1, 2],
[3, 4]]
)DOC");
}
};
class CropOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/crop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_GPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators { // Internal
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor;
template <typename T>
class CropKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
x->dims().size(), offsets.size(),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0;
for (int i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
}
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
out->dims(), out_stride, out_data);
}
};
template <typename Place, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < D; ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
}
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
d_x_tensor.device(context.GetEigenDevice<Place>()) =
d_out_tensor.pad(paddings, 0);
}
}
template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
CropGradFunction<Place, T, 1>(context);
break;
case 2:
CropGradFunction<Place, T, 2>(context);
break;
case 3:
CropGradFunction<Place, T, 3>(context);
break;
case 4:
CropGradFunction<Place, T, 4>(context);
break;
case 5:
CropGradFunction<Place, T, 5>(context);
break;
case 6:
CropGradFunction<Place, T, 6>(context);
break;
default:
PADDLE_THROW(
"CropOp only support tensors with no more than 6 dimensions.");
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace detail {
template <typename T, int Rank>
struct StridedMemcpyFunctor;
template <typename T>
struct StridedMemcpyFunctor<T, 1> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<1> src_stride, framework::Dim<1> dst_dim,
framework::Dim<1> dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
} else {
#ifndef PADDLE_ONLY_CPU
auto& gpu_place = boost::get<platform::GPUPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
}
}
};
template <typename T, int Rank>
struct StridedMemcpyFunctor {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<Rank> src_stride, framework::Dim<Rank> dst_dim,
framework::Dim<Rank> dst_stride, T* dst) const {
for (int64_t i = 0; i < dst_dim.head; ++i) {
StridedMemcpyFunctor<T, Rank - 1> func;
func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst);
src += src_stride.head;
dst += dst_stride.head;
}
}
};
template <typename T>
struct StridedCopyDimVisitor : public boost::static_visitor<void> {
StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& src_stride,
const framework::DDim& dst_stride, T* dst)
: dev_ctx_(dev_ctx),
src_(src),
src_stride_(src_stride),
dst_stride_(dst_stride),
dst_(dst) {}
template <typename Dim>
void operator()(Dim dst_dim) const {
Dim src_stride = boost::get<Dim>(src_stride_);
Dim dst_stride = boost::get<Dim>(dst_stride_);
constexpr int dim = Dim::dimensions;
StridedMemcpyFunctor<T, dim> functor;
functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_);
}
const platform::DeviceContext& dev_ctx_;
const T* src_;
const framework::DDim& src_stride_;
const framework::DDim& dst_stride_;
T* dst_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
int filter_height = filter.dims()[filter.dims().size() - 2];
int filter_width = filter.dims()[filter.dims().size() - 1];
int output_channels = output->dims()[1];
int output_height = output->dims()[2];
int output_width = output->dims()[3];
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
// use col_shape in the im2col calculation
framework::DDim col_shape = {input_channels / groups, filter_height,
filter_width, output_height, output_width};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {
input_channels / groups * filter_height * filter_width,
output_height * output_width};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3]};
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {output_channels,
output_height * output_width};
// convolution operator: im2col + gemm
int in_step = input_channels / groups;
int out_step = output_channels / groups;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice<T>(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[1]);
// gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
}
}
}
};
template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
int filter_height = filter.dims()[filter.dims().size() - 2];
int filter_width = filter.dims()[filter.dims().size() - 1];
int output_channels = output_grad->dims()[1];
int output_height = output_grad->dims()[2];
int output_width = output_grad->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {input_channels / groups, filter_height,
filter_width, output_height, output_width};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {
input_channels / groups * filter_height * filter_width,
output_height * output_width};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3]};
framework::DDim output_matrix_shape = {
output_grad->dims()[1],
output_grad->dims()[2] * output_grad->dims()[3]};
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
// convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm
int in_step = input_channels / groups;
int out_step = output_channels / groups;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch =
input_grad->Slice<T>(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice =
filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
// col2im
Tensor in_grad_slice =
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
}
}
}
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
auto t = framework::EigenVector<T>::Flatten(filter_grad_);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -27,9 +27,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -79,9 +80,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -137,9 +138,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -197,9 +199,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......
......@@ -64,9 +64,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -84,9 +85,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(block_x, block_y);
im2col<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im2col<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
......@@ -149,9 +150,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -174,9 +175,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
......@@ -235,9 +236,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -268,9 +270,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
im2colOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
......@@ -318,9 +320,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -351,9 +353,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
col2imOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
......
......@@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
int padding_width);
};
template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width);
};
} // namespace math
......
......@@ -78,8 +78,8 @@ void testIm2col() {
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......
......@@ -96,7 +96,7 @@ class PReluGradKernel : public framework::OpKernel {
trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
PReluGradFunctor<T>(alpha_ptr));
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
// TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rank_loss_op.h"
namespace paddle {
namespace operators {
class RankLossOp : public framework::OperatorWithKernel {
public:
RankLossOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// input check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
"Input(Left) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
"Input(Right) shouldn't be null");
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be row vector with size batch_size x 1.");
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
}
};
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RankLossOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label",
"The label indicating A ranked higher than B or not, row vector.");
AddInput("Left", "The output of RankNet for doc A, vector.");
AddInput("Right", "The output of RankNet for doc B, vetor");
AddOutput("Out", "The output loss of RankLoss operator, vector.");
AddComment(R"DOC(RankLoss operator
Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with
one training sample consisting of a pair of doc A and B, and the label P
indicating that A is ranked higher than B or not:
P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
the input pair.
The RankLoss operator contains three inputs: Left (o_i), Right (o_j) and Label
(P_{i,j}), which represent the output of RankNet for two docs and the label
respectively, and yields the rank loss C_{i,j} by following the expression
\f[
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
o_{i,j} = o_i - o_j \\
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
\f]
The operator can take inputs of one sample or in batch.
[1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to
Rank using Gradient Descent.
http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf
)DOC");
}
};
class RankLossGradOp : public framework::OperatorWithKernel {
public:
RankLossGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
"Input(Left) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
"Input(Right) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto dims = ctx.Input<framework::Tensor>("Left")->dims();
auto *left_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Left"));
auto *right_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Right"));
if (left_grad) {
left_grad->Resize(dims);
}
if (right_grad) {
right_grad->Resize(dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rank_loss, ops::RankLossOp, ops::RankLossOpMaker, rank_loss_grad,
ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(rank_loss,
ops::RankLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rank_loss_grad, ops::RankLossGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rank_loss_op.h"
REGISTER_OP_GPU_KERNEL(
rank_loss,
paddle::operators::RankLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
rank_loss_grad,
paddle::operators::RankLossGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class RankLossKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* label_t = ctx.Input<framework::Tensor>("Label");
auto* left_t = ctx.Input<framework::Tensor>("Left");
auto* right_t = ctx.Input<framework::Tensor>("Right");
out_t->mutable_data<T>(ctx.GetPlace());
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto label = framework::EigenVector<T>::Flatten(*label_t);
auto left = framework::EigenVector<T>::Flatten(*left_t);
auto right = framework::EigenVector<T>::Flatten(*right_t);
auto& dev = ctx.GetEigenDevice<Place>();
out.device(dev) =
(1. + (left - right).exp()).log() - label * (left - right);
}
};
template <typename Place, typename T>
class RankLossGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_left_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Left"));
auto* d_right_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Right"));
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* label_t = ctx.Input<framework::Tensor>("Label");
auto* left_t = ctx.Input<framework::Tensor>("Left");
auto* right_t = ctx.Input<framework::Tensor>("Right");
auto& dev = ctx.GetEigenDevice<Place>();
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto label = framework::EigenVector<T>::Flatten(*label_t);
auto left = framework::EigenVector<T>::Flatten(*left_t);
auto right = framework::EigenVector<T>::Flatten(*right_t);
// compute d_left
if (d_left_t) {
d_left_t->mutable_data<T>(ctx.GetPlace());
auto d_left = framework::EigenVector<T>::Flatten(*d_left_t);
d_left.device(dev) = d_out * (1. / (1. + (right - left).exp()) - label);
}
// compute d_right
if (d_right_t) {
d_right_t->mutable_data<T>(ctx.GetPlace());
auto d_right = framework::EigenVector<T>::Flatten(*d_right_t);
d_right.device(dev) =
-d_out * (1.0 / (1. + (right - left).exp()) - label);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -29,9 +29,11 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<LoDTensor>()
->dims()[0];
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len_, 0);
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
......@@ -123,14 +125,12 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
}
const rnn::ArgumentName RecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks",
"outlinks", "inlink_alias", "outlink_alias",
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
const rnn::ArgumentName RecurrentGradientOp::kArgName{
"step_net", "step_scopes", "outlink@grad",
"inlink@grad", "inlink_alias", "outlink_alias",
"memories", "pre_memories", "boot_memories@grad"};
"step_net", "step_scopes", "outlink@grad", "inlink@grad",
"memories", "pre_memories", "boot_memories@grad"};
RecurrentOp::RecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
......@@ -160,8 +160,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
......@@ -206,9 +204,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
}
void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<LoDTensor>()
->dims()[0];
seq_len_ =
scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
......
......@@ -24,22 +24,23 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len,
bool infer_shape_mode) {
const std::vector<std::string>& inlinks,
const size_t seq_len, bool infer_shape_mode) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
inlinks[i].external);
// global inputs
auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
inlinks[i]);
LoDTensor* input = input_var->GetMutable<LoDTensor>();
f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
"all the inlinks be the same length");
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
if (!infer_shape_mode) {
// The input of operators of each step is Tensor here.
// Maybe need to modify Slice function.
......@@ -51,18 +52,17 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
}
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks, const size_t seq_len,
bool infer_shape_mode) {
const std::vector<std::string>& outlinks,
const size_t seq_len, bool infer_shape_mode) {
for (size_t i = 0; i < outlinks.size(); i++) {
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
outlinks[i].external);
auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
outlinks[i]);
LoDTensor* output = output_var->GetMutable<LoDTensor>();
if (infer_shape_mode) {
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal);
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
......@@ -71,9 +71,8 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
} else {
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
LoDTensor* step_output = step_scopes[j]
->FindVar(outlinks[i].internal)
->GetMutable<LoDTensor>();
LoDTensor* step_output =
step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
......@@ -113,29 +112,9 @@ void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op) {
arg->step_scopes = op.Output(name.step_scopes);
auto inlinks = op.Inputs(name.inlinks);
auto inlink_alias = op.Attr<std::vector<std::string>>(name.inlink_alias);
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
"the size of inlinks and inlink_alias don't match:%d,%d",
inlinks.size(), inlink_alias.size());
for (size_t i = 0; i < inlinks.size(); ++i) {
rnn::Link link;
link.external = inlinks[i];
link.internal = inlink_alias[i];
(arg->inlinks).push_back(link);
}
arg->inlinks = op.Inputs(name.inlinks);
auto outlinks = op.Outputs(name.outlinks);
auto outlink_alias = op.Attr<std::vector<std::string>>(name.outlink_alias);
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
"the size of outlinks and outlink_alias don't match:%d,%d",
outlinks.size(), outlink_alias.size());
for (size_t i = 0; i < outlinks.size(); ++i) {
rnn::Link link;
link.external = outlinks[i];
link.internal = outlink_alias[i];
(arg->outlinks).push_back(link);
}
arg->outlinks = op.Outputs(name.outlinks);
auto boot_memories = op.Inputs(name.boot_memories);
......
......@@ -41,18 +41,11 @@ struct MemoryAttr {
std::string boot_var;
};
struct Link {
// input or output links name.
std::string internal;
// alias to avoid duplicate keys in scopes.
std::string external;
};
struct Argument {
std::string step_net;
std::string step_scopes;
std::vector<Link> inlinks;
std::vector<Link> outlinks;
std::vector<std::string> inlinks;
std::vector<std::string> outlinks;
std::vector<rnn::MemoryAttr> memories;
};
......@@ -61,8 +54,6 @@ struct ArgumentName {
std::string step_scopes;
std::string inlinks;
std::string outlinks;
std::string inlink_alias; // the alias of inlinks in step net.
std::string outlink_alias; // the alias of outlinks in step net.
std::string memories; // the memory name
std::string pre_memories; // the previous memory name
std::string boot_memories; // the boot memory name
......@@ -72,15 +63,15 @@ struct ArgumentName {
* Prepare inputs for each step net.
*/
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len,
bool infer_shape_mode);
const std::vector<std::string>& inlinks,
const size_t seq_len, bool infer_shape_mode);
/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks, const size_t seq_len,
bool infer_shape_mode);
const std::vector<std::string>& outlinks,
const size_t seq_len, bool infer_shape_mode);
void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const size_t step_id,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/detail/strided_memcpy.h"
namespace paddle {
namespace operators {
// Strided memory copy from src to dst.
//
// The src and dst should be both on dev_ctx.GetPlace(), otherwise, there will
// be a segment fault.
//
// The stride of an array (also referred to as increment, pitch or step size) is
// the number of locations in memory between beginnings of successive array
// elements
//
// For example, for tensor like [1, 3, 300, 300]. If there is no padding, the
// stride is [270000, 90000, 300, 1].
//
// NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke
// `dev_ctx.Wait()`.
template <typename T>
inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& src_stride,
const framework::DDim& dst_dim,
const framework::DDim& dst_stride, T* dst) {
using namespace detail;
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/strided_memcpy.h"
#include "gtest/gtest.h"
#include "paddle/memory/memory.h"
namespace paddle {
namespace operators {
TEST(StridedMemcpy, CPUCrop) {
// clang-format off
int src[] = {
0, 1, 2, 0, 0,
0, 3, 4, 0, 0,
0, 0, 0, 0, 0,
};
// clang-format on
framework::DDim src_stride({5, 1});
int dst[4];
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({2, 1});
platform::CPUDeviceContext ctx;
StridedMemcpy<int>(ctx, src + 1, src_stride, dst_dim, dst_stride, dst);
ASSERT_EQ(1, dst[0]);
ASSERT_EQ(2, dst[1]);
ASSERT_EQ(3, dst[2]);
ASSERT_EQ(4, dst[3]);
}
TEST(StridedMemcpy, CPUConcat) {
// clang-format off
int src[] = {
1, 2,
3, 4
};
// clang-format on
int dst[8];
framework::DDim src_stride({2, 1});
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({4, 1});
platform::CPUDeviceContext ctx;
StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst);
StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst + 2);
// clang-format off
int expect_dst[] = {
1, 2, 1, 2,
3, 4, 3, 4
};
// clang-format on
for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
ASSERT_EQ(expect_dst[i], dst[i]);
}
}
#ifndef PADDLE_ONLY_CPU
TEST(StridedMemcpy, GPUCrop) {
// clang-format off
int src[] = {
0, 1, 2, 0, 0,
0, 3, 4, 0, 0,
0, 0, 0, 0, 0,
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CPUPlace cpu;
int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));
framework::DDim src_stride({5, 1});
int dst[4];
int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({2, 1});
platform::CUDADeviceContext ctx(gpu0);
StridedMemcpy<int>(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride,
gpu_dst);
memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
ctx.Wait();
ASSERT_EQ(1, dst[0]);
ASSERT_EQ(2, dst[1]);
ASSERT_EQ(3, dst[2]);
ASSERT_EQ(4, dst[3]);
memory::Free(gpu0, gpu_dst);
memory::Free(gpu0, gpu_src);
}
TEST(StridedMemcpy, GPUConcat) {
// clang-format off
int src[] = {
1, 2,
3, 4
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CPUPlace cpu;
int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));
int dst[8];
int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));
framework::DDim src_stride({2, 1});
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({4, 1});
platform::CUDADeviceContext ctx(gpu0);
StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst);
StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride,
gpu_dst + 2);
memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
ctx.Wait();
// clang-format off
int expect_dst[] = {
1, 2, 1, 2,
3, 4, 3, 4
};
// clang-format on
for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
ASSERT_EQ(expect_dst[i], dst[i]);
}
memory::Free(gpu0, gpu_dst);
memory::Free(gpu0, gpu_src);
}
#endif
} // namespace operators
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/transpose_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank, axis_size,
"the input tensor's rank(%d) "
"should be equal to the axis's size(%d)",
x_rank, axis_size);
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
"Each element of Attribute axis should be a unique value "
"range from 0 to (dims - 1), "
"where the dims is the axis's size");
}
framework::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]];
}
ctx.Output<framework::LoDTensor>("Out")->Resize(out_dims);
}
};
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor)The input tensor, tensors with rank at most 6 are supported");
AddOutput("Out", "(Tensor)The output tensor");
AddAttr<std::vector<int>>(
"axis",
"(vector<int>)a list of values, and the size of the list should be "
"the same with the input tensor rank, the tensor will "
"permute the axes according the the values given");
AddComment(R"DOC(
The Tensor will be permuted according to the axis values given.
The op is very much like the numpy.transpose function in python
For example:
>> input = numpy.arange(6).reshape((2,3))
>> input
array([[0, 1, 2],
[3, 4, 5]])
>> axis = [1, 0]
>> output = input.transpose(axis)
>> output
array([[0, 3],
[1, 4],
[2, 5]])
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
)DOC");
}
};
class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (x_grad) x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad,
ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUPlace, float>);
......@@ -12,12 +12,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/transpose_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
transpose_grad,
ops::TransposeGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T, int Rank>
void EigenTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out.dims();
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(out);
auto& dev = context.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in.shuffle(permute);
}
template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *x, *out, axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *x, *out, axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *x, *out, axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *x, *out, axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *x, *out, axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *x, *out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
};
template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
if (x_grad) {
x_grad->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *out_grad, *x_grad,
reversed_axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestExp(OpTest):
def setUp(self):
self.op_type = "exp"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.exp(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSigmoid(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestTanh(OpTest):
def setUp(self):
self.op_type = "tanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.tanh(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.sqrt(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestAbs(OpTest):
def setUp(self):
self.op_type = "abs"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
# Because we set delta = 0.005 in caculating numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003
# x_pos will be 0.007, so the numeric gradient is unaccurate.
# we should avoid this
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.abs(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestRelu(OpTest):
def setUp(self):
self.op_type = "relu"
x = np.random.uniform(-1, 1, [11, 17]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestBRelu(OpTest):
def setUp(self):
self.op_type = "brelu"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
t_min = 1
t_max = 4
# The same with TestAbs
x[np.abs(x - t_min) < 0.005] = t_min + 0.02
x[np.abs(x - t_max) < 0.005] = t_max + 0.02
self.inputs = {'X': x}
self.attrs = {'t_min': t_min, 't_max': t_max}
t = np.copy(x)
t[t < t_min] = t_min
t[t > t_max] = t_max
self.outputs = {'Y': t}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestSoftRelu(OpTest):
def setUp(self):
self.op_type = "soft_relu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
threshold = 2
# The same reason with TestAbs
x[np.abs(x - threshold) < 0.005] = threshold + 0.02
x[np.abs(x + threshold) < 0.005] = -threshold + 0.02
self.inputs = {'X': x}
self.attrs = {'threshold': threshold}
t = np.copy(x)
t[t < -threshold] = -threshold
t[t > threshold] = threshold
self.outputs = {'Y': np.log((np.exp(t) + 1))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestReciprocal(OpTest):
def setUp(self):
self.op_type = "reciprocal"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.outputs = {'Y': np.reciprocal(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.01)
class TestLog(OpTest):
def setUp(self):
self.op_type = "log"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.log(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSquare(OpTest):
def setUp(self):
self.op_type = "square"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.square(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestPow(OpTest):
def setUp(self):
self.op_type = "pow"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.attrs = {'factor': 3}
self.outputs = {'Y': np.power(self.inputs['X'], 3)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestSTanh(OpTest):
def setUp(self):
self.op_type = "stanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
scale_a = 2.0 / 3.0
scale_b = 1.7159
self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestClipOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
input[np.abs(input - self.min) < self.max_relative_error] = 0.5
input[np.abs(input - self.max) < self.max_relative_error] = 0.5
self.op_type = "clip"
self.inputs = {'X': input, }
self.attrs = {}
self.attrs['min'] = self.min
self.attrs['max'] = self.max
self.outputs = {
'Out': np.clip(self.inputs['X'], self.attrs['min'],
self.attrs['max'])
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X'], 'Out', max_relative_error=self.max_relative_error)
def initTestCase(self):
self.shape = (4, 4)
self.max = 0.7
self.min = 0.1
class TestCase1(TestClipOp):
def initTestCase(self):
self.shape = (8, 16, 8)
self.max = 0.7
self.min = 0
class TestCase2(TestClipOp):
def initTestCase(self):
self.shape = (8, 16)
self.max = 1
self.min = 0
class TestCase3(TestClipOp):
def initTestCase(self):
self.shape = (4, 8, 16)
self.max = 0.7
self.min = 0.2
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestConv2dOp(OpTest):
def setUp(self):
self.init_groups()
self.op_type = "conv2d"
batch_size = 2
input_channels = 3
input_height = 5
input_width = 5
output_channels = 6
filter_height = 3
filter_width = 3
stride = 1
padding = 0
output_height = (input_height - filter_height + 2 * padding
) / stride + 1
output_width = (input_width - filter_width + 2 * padding) / stride + 1
input = np.random.random((batch_size, input_channels, input_height,
input_width)).astype("float32")
filter = np.random.random(
(output_channels, input_channels / self.groups, filter_height,
filter_width)).astype("float32")
output = np.ndarray(
(batch_size, output_channels, output_height, output_width))
self.inputs = {'Input': input, 'Filter': filter}
self.attrs = {
'strides': [1, 1],
'paddings': [0, 0],
'groups': self.groups
}
output_group_channels = output_channels / self.groups
input_group_channels = input_channels / self.groups
for batchid in xrange(batch_size):
for group in xrange(self.groups):
for outchannelid in range(group * output_group_channels,
(group + 1) * output_group_channels):
for rowid in xrange(output_height):
for colid in xrange(output_width):
start_h = (rowid * stride) - padding
start_w = (colid * stride) - padding
output_value = 0.0
for inchannelid in range(
group * input_group_channels,
(group + 1) * input_group_channels):
for frowid in xrange(filter_height):
for fcolid in xrange(filter_width):
input_value = 0.0
inrowid = start_h + frowid
incolid = start_w + fcolid
if ((inrowid >= 0 and
inrowid < input_height) and
(incolid >= 0 and
incolid < input_width)):
input_value = input[batchid][
inchannelid][inrowid][incolid]
filter_value = filter[outchannelid][
inchannelid % input_group_channels][
frowid][fcolid]
output_value += input_value * filter_value
output[batchid][outchannelid][rowid][
colid] = output_value
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Input']))
def init_groups(self):
self.groups = 1
class TestWithGroup(TestConv2dOp):
def init_groups(self):
self.groups = 3
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
def crop(data, offsets, crop_shape):
def indexOf(shape, index):
result = []
for dim in reversed(shape):
result.append(index % dim)
index = index / dim
return result[::-1]
result = []
for i, value in enumerate(data.flatten()):
index = indexOf(data.shape, i)
selected = True
if len(index) == len(offsets):
for j, offset in enumerate(offsets):
selected = selected and index[j] >= offset and index[
j] < crop_shape[j] + offset
if selected:
result.append(value)
return np.array(result).reshape(crop_shape)
class TestCropOp(OpTest):
def setUp(self):
self.op_type = "crop"
self.crop_by_input = False
self.attrs = {}
self.initTestCase()
self.attrs['offsets'] = self.offsets
if self.crop_by_input:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Y': np.random.random(self.crop_shape).astype("float32")
}
else:
self.attrs['shape'] = self.crop_shape
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
}
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.006)
class TestCase1(TestCropOp):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
class TestCase2(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 8)
self.crop_shape = [4, 8]
self.offsets = [0, 0]
class TestCase3(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_input = True
class TestCase4(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 4)
self.crop_shape = [4, 4]
self.offsets = [0, 0]
self.crop_by_input = True
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestRankLossOp(OpTest):
def setUp(self):
self.op_type = "rank_loss"
batch_size = 5
# labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
left = np.random.random((batch_size, 1)).astype("float32")
right = np.random.random((batch_size, 1)).astype("float32")
loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
self.inputs = {'Label': label, 'Left': left, 'Right': right}
self.outputs = {'Out': loss}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Left", "Right"], "Out")
def test_check_grad_ignore_left(self):
self.check_grad(["Right"], "Out", no_grad_set=set('Left'))
def test_check_grad_ignore_right(self):
self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
if __name__ == '__main__':
unittest.main()
......@@ -59,7 +59,6 @@ class PySimpleRNNTest(unittest.TestCase):
def test_forward(self):
output = self.rnn.forward()
print 'output', output
def create_tensor(scope, name, shape, np_data):
......@@ -103,7 +102,7 @@ class TestRecurrentOp(unittest.TestCase):
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.infer_shape(self.scope)
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h").get_tensor())
return np.array(self.scope.find_var("h@mem").get_tensor())
def create_global_variables(self):
# create inlink
......@@ -123,8 +122,7 @@ class TestRecurrentOp(unittest.TestCase):
create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim],
h_boot_np_data)
self.scope.new_var("step_scopes")
self.scope.new_var("h@alias")
self.scope.new_var("h")
self.scope.new_var("h@mem")
def create_rnn_op(self):
# create RNNOp
......@@ -134,20 +132,18 @@ class TestRecurrentOp(unittest.TestCase):
boot_memories=["h_boot"],
step_net="stepnet",
# outputs
outlinks=["h"],
outlinks=["h@mem"],
step_scopes="step_scopes",
# attributes
inlink_alias=["x@alias"],
outlink_alias=["h@alias"],
pre_memories=["h@pre"],
memories=["h@alias"])
memories=["h@mem"])
def create_step_net(self):
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add", X="Wx", Y="Uh", Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@alias")
sig_op = Operator("sigmoid", X="sum", Y="h@mem")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.append_op(op)
......
import unittest
import numpy as np
from op_test import OpTest
class TestSigmoidOp(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestTransposeOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "transpose"
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {'axis': list(self.axis)}
self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.shape = (3, 4)
self.axis = (1, 0)
class TestCase0(TestTransposeOp):
def initTestCase(self):
self.shape = (3, )
self.axis = (0, )
class TestCase1(TestTransposeOp):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (0, 2, 1)
class TestCase2(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册