提交 efc119b4 编写于 作者: Y Yu Yang

Add type_alias to import framework into ops

Make implement an operator less noisy.
上级 0c2790f7
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -35,10 +32,10 @@ protected: ...@@ -35,10 +32,10 @@ protected:
} }
}; };
class AddOpMaker : public framework::OpProtoAndCheckerMaker { class AddOpMaker : public OpProtoAndCheckerMaker {
public: public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op"); AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op"); AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op"); AddOutput("Out", "The output of add op");
...@@ -50,11 +47,10 @@ The equation is: Out = X + Y ...@@ -50,11 +47,10 @@ The equation is: Out = X + Y
} }
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "AddOpGrad"; LOG(INFO) << "AddOpGrad";
return ""; return "";
...@@ -64,7 +60,6 @@ protected: ...@@ -64,7 +60,6 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad); REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two, REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
...@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and ...@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
framework::EigenVector<T>::Flatten(input1);
} }
}; };
......
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, PADDLE_ENFORCE(outputs.size() == 1,
...@@ -35,15 +32,14 @@ protected: ...@@ -35,15 +32,14 @@ protected:
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1, PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
"label's dimension must be 1."); "label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]})); outputs[0]->Resize({inputs[0]->dims()[0]});
} }
}; };
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public: public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto, OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp"); AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp"); AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp"); AddOutput("Y", "The output of OnehotCrossEntropyOp");
...@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator. ...@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
} // namespace paddle } // namespace paddle
REGISTER_OP(onehot_cross_entropy, REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel< ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
::paddle::platform::GPUPlace, float>); \ No newline at end of file
\ No newline at end of file
...@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel { class OnehotCrossEntropyOpKernel : public OpKernel {
public: public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<framework::Tensor>(); auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>(); const T* X_data = X.data<T>();
const int* label_data = const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
context.Input(1)->Get<framework::Tensor>().data<int>(); auto* Y = context.Output(0)->GetMutable<Tensor>();
auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(context.GetPlace());
......
...@@ -12,41 +12,38 @@ ...@@ -12,41 +12,38 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/net.h" #include "type_alias.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FullyConnectedOp : public framework::PlainNet { class FullyConnectedOp : public PlainNet {
public: public:
void Init() override { void Init() override {
AddOp(framework::OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
{ {
Input("X"), Input("W"), Input("X"), Input("W"),
}, },
{Output("before_act")}, {Output("before_act")},
{})); {}));
auto b = Input("b"); auto b = Input("b");
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { if (b != EMPTY_VAR_NAME()) {
AddOp(framework::OpRegistry::CreateOp("rowwise_add", AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")}, {Output("before_act"), Input("b")},
{Output("before_act")}, {Output("before_act")},
{})); {}));
} }
auto activation = GetAttr<std::string>("activation"); auto activation = GetAttr<std::string>("activation");
AddOp(framework::OpRegistry::CreateOp( AddOp(OpRegistry::CreateOp(
activation, {Output("before_act")}, {Output("Y")}, {})); activation, {Output("before_act")}, {Output("Y")}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
public: public:
FullyConnectedOpMaker(framework::OpProto *proto, FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator"); AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator"); AddInput("W", "the weight of fc operator");
...@@ -71,6 +68,4 @@ USE_OP(rowwise_add); ...@@ -71,6 +68,4 @@ USE_OP(rowwise_add);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(softmax); USE_OP(softmax);
REGISTER_OP(fc, REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
paddle::operators::FullyConnectedOp,
paddle::operators::FullyConnectedOpMaker);
...@@ -13,17 +13,14 @@ ...@@ -13,17 +13,14 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class MulOp : public framework::OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
auto dim0 = inputs[0]->dims(); auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = inputs[1]->dims();
...@@ -37,10 +34,10 @@ protected: ...@@ -37,10 +34,10 @@ protected:
} }
}; };
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op"); AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op"); AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op"); AddOutput("Out", "The output of mul op");
...@@ -52,11 +49,10 @@ The equation is: Out = X * Y ...@@ -52,11 +49,10 @@ The equation is: Out = X * Y
} }
}; };
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
return ""; return "";
...@@ -66,8 +62,7 @@ protected: ...@@ -66,8 +62,7 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad); REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -13,8 +13,5 @@ ...@@ -13,8 +13,5 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(mul, REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
paddle::operators::MulKernel<paddle::platform \ No newline at end of file
::GPUPlace, float>);
\ No newline at end of file
...@@ -14,30 +14,27 @@ ...@@ -14,30 +14,27 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<framework::Tensor>(); auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenMatrix<T>::From(*output).device( EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
framework::EigenMatrix<T>::From(input0).contract( dim_pair);
framework::EigenMatrix<T>::From(input1), dim_pair);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -13,15 +13,13 @@ ...@@ -13,15 +13,13 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add");
auto dim0 = inputs[0]->dims(); auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = inputs[1]->dims();
...@@ -34,11 +32,10 @@ protected: ...@@ -34,11 +32,10 @@ protected:
} }
}; };
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(framework::OpProto *proto, RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix"); AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector"); AddInput("b", "The right input of row-wise add op, must be vector");
AddOutput("Out", "The output of row-wise add op"); AddOutput("Out", "The output of row-wise add op");
...@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]): ...@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]):
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(rowwise_add, REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
paddle::operators::RowWiseAddOp, REGISTER_OP_CPU_KERNEL(rowwise_add,
paddle::operators::RowWiseAddOpMaker); ops::RowWiseAddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add,
paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
#include "paddle/framework/op_registry.h"
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(rowwise_add,
rowwise_add, ops::RowWiseAddKernel<ops::GPUPlace, float>);
paddle::operators::RowWiseAddKernel<paddle::platform ::GPUPlace, float>);
...@@ -13,25 +13,23 @@ ...@@ -13,25 +13,23 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public framework::OpKernel { class RowWiseAddKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto in0 = context.Input(0)->Get<framework::Tensor>(); auto in0 = context.Input(0)->Get<Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>(); auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>(); auto* out = context.Output(0)->GetMutable<Tensor>();
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto input = framework::EigenMatrix<T>::From(in0); auto input = EigenMatrix<T>::From(in0);
auto bias = framework::EigenVector<T>::From(in1); auto bias = EigenVector<T>::From(in1);
auto output = framework::EigenMatrix<T>::From(*out); auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size; const int rest_size = input.size() / bias_size;
......
...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SGDOp : public framework::OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");
...@@ -35,10 +32,10 @@ protected: ...@@ -35,10 +32,10 @@ protected:
} }
}; };
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter"); AddInput("param", "input parameter");
AddInput("grad", "input gradient"); AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter"); AddOutput("param_out", "output parameter");
...@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad; ...@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker); REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float> REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>);
SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float; REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float); \ No newline at end of file
\ No newline at end of file
...@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and ...@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel { class SGDOpKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& ctx) const override { void Compute(const KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<framework::Tensor>(); auto param = ctx.Input("param")->Get<Tensor>();
auto grad = ctx.Input("grad")->Get<framework::Tensor>(); auto grad = ctx.Input("grad")->Get<Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>(); auto* param_out = ctx.Output(0)->GetMutable<Tensor>();
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
framework::EigenVector<T>::Flatten(*param_out) EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
.device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
framework::EigenVector<T>::Flatten(param) -
lr * framework::EigenVector<T>::Flatten(grad);
} }
}; };
......
...@@ -13,37 +13,33 @@ ...@@ -13,37 +13,33 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public framework::OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
outputs[0]->Resize(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public: public:
SigmoidOpMaker(framework::OpProto *proto, SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input"); AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output"); AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function"); AddComment("Sigmoid function");
} }
}; };
class SigmoidOpGrad : public framework::OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad"; LOG(INFO) << "SigmoidGrad";
return ""; return "";
...@@ -53,11 +49,7 @@ protected: ...@@ -53,11 +49,7 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(sigmoid, REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
paddle::operators::SigmoidOp, REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
sigmoid,
paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
...@@ -14,25 +14,23 @@ ...@@ -14,25 +14,23 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel { class SigmoidKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -12,16 +12,14 @@ ...@@ -12,16 +12,14 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
...@@ -31,10 +29,9 @@ protected: ...@@ -31,10 +29,9 @@ protected:
} }
}; };
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(framework::OpProto *proto, SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax"); AddInput("X", "input of softmax");
AddOutput("Y", "output of softmax"); AddOutput("Y", "output of softmax");
...@@ -42,11 +39,10 @@ public: ...@@ -42,11 +39,10 @@ public:
} }
}; };
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<const framework::Tensor *> &inputs, const std::vector<Tensor *> &outputs) const override {}
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad"; LOG(INFO) << "SoftmaxOpGrad";
return ""; return "";
...@@ -56,9 +52,6 @@ protected: ...@@ -56,9 +52,6 @@ protected:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad); REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
...@@ -14,23 +14,21 @@ ...@@ -14,23 +14,21 @@
#pragma once #pragma once
#include "glog/logging.h" #include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel { class SoftmaxKernel : public OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>(); auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
auto logits = framework::EigenMatrix<T>::From(input); auto logits = EigenMatrix<T>::From(input);
auto softmax = framework::EigenMatrix<T>::From(*output); auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0; const int kBatchDim = 0;
const int kClassDim = 1; const int kClassDim = 1;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using OperatorWithKernel = framework::OperatorWithKernel;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using PlainNet = framework::PlainNet;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册