Commit 55115ac6 authored by Yu Yang, committed by GitHub

Merge pull request #3067 from reyoung/make_network_op

Make network op
......@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
......@@ -20,17 +20,7 @@
namespace paddle {
namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set;
......@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index;
}
std::string PlainNet::DebugString() const {
std::string NetOp::DebugString() const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) {
......@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str();
}
bool NetOp::IsNetOp() const { return true; }
} // namespace framework
} // namespace paddle
......@@ -37,21 +37,7 @@ namespace framework {
* This is the base class of a network; all networks should implement the APIs
* it defines.
*/
class Net : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net: it creates a list of operators and runs them
* sequentially, in the order they were added.
*/
class PlainNet : public Net {
class NetOp : public OperatorBase {
public:
/**
* Infer all the operators' input and output variables' shapes; will be called
......@@ -80,15 +66,17 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) override {
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}
void CompleteAddOp(bool calculate = true) override;
void CompleteAddOp(bool calculate = true);
std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
......@@ -100,7 +88,5 @@ class PlainNet : public Net {
}
};
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework
} // namespace paddle
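For reference, a minimal sketch of driving the NetOp API directly; it mirrors net_op_test.cc below, and the op names assume the operators registered elsewhere in this PR:

auto net = std::make_shared<paddle::framework::NetOp>();
net->AddOp(paddle::framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->CompleteAddOp();  // seals the net; any later AddOp throws EnforceNotMet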
......@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>();
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
......@@ -71,28 +71,21 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
//! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
//}
} // namespace framework
} // namespace paddle
syntax="proto2";
package paddle.framework;
import "op_proto.proto";
message NetDesc {
// network identification
optional string name = 1;
// operators contained in the network
repeated OpProto operators = 2;
// network type to run with, e.g. "plainNet", "DAG"
optional string net_type = 3;
// number of workers
optional int32 num_workers = 4;
}
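A sketch of populating NetDesc through the generated proto2 C++ API; the op_proto variable below is hypothetical:

paddle::framework::NetDesc desc;
desc.set_name("my_net");           // network identification
desc.set_net_type("plainNet");     // or "DAG", per the comment above
desc.set_num_workers(1);
*desc.add_operators() = op_proto;  // append an OpProto describing one operator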
......@@ -90,15 +90,17 @@ class OperatorBase {
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;
// Get an input with the argument name described in `op_proto`
virtual bool IsNetOp() const { return false; }
//! Get an input with the argument name described in `op_proto`
const std::string& Input(const std::string& name) const;
// Get an input which has multiple variables.
// TODO: add a vector_view to prevent memory copy.
//! Get an input which has multiple variables.
//! TODO: add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
// Get an output with the argument name described in `op_proto`
//! Get an output with the argument name described in `op_proto`
const std::string& Output(const std::string& name) const;
// Get an output which has multiple variables.
// TODO: add a vector_view to prevent memory copy.
//! Get an output which has multiple variables.
//! TODO: add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;
public:
......
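The IsNetOp() hook added above enables an RTTI-free downcast; a sketch (the same static_pointer_cast shows up in the pybind changes at the end of this diff):

std::shared_ptr<paddle::framework::OperatorBase> op = /* ... */;
if (op->IsNetOp()) {  // safe: only NetOp overrides this to return true
  auto net = std::static_pointer_cast<paddle::framework::NetOp>(op);
  net->CompleteAddOp();
}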
......@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
class AddOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
......@@ -35,10 +32,10 @@ protected:
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
......@@ -50,11 +47,10 @@ The equation is: Out = X + Y
}
};
class AddOpGrad : public framework::OperatorWithKernel {
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
......@@ -64,7 +60,6 @@ protected:
} // namespace operators
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
......@@ -13,27 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device(
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1);
}
};
......
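The Compute body above relies on Eigen's unsupported Tensor module; a standalone sketch of the same device-assignment pattern, assuming Eigen 3 is available:

#include <unsupported/Eigen/CXX11/Tensor>

Eigen::Tensor<float, 1> a(4), b(4), out(4);
a.setConstant(1.f);
b.setConstant(2.f);
Eigen::DefaultDevice cpu;   // GetEigenDevice<Place>() returns such a device
out.device(cpu) = a + b;    // evaluates the elementwise sum on that device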
......@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
class OnehotCrossEntropyOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1,
......@@ -35,15 +32,14 @@ protected:
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
"label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
outputs[0]->Resize({inputs[0]->dims()[0]});
}
};
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp");
......@@ -59,9 +55,7 @@ OnehotCrossEntropy Operator.
} // namespace paddle
REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<
::paddle::platform::GPUPlace, float>);
\ No newline at end of file
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
class OnehotCrossEntropyOpKernel : public OpKernel {
public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const framework::KernelContext& context) const override {
auto X = context.Input(0)->Get<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto X = context.Input(0)->Get<Tensor>();
const T* X_data = X.data<T>();
const int* label_data =
context.Input(1)->Get<framework::Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<framework::Tensor>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>();
Y->mutable_data<T>(context.GetPlace());
......
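Given the LOG_THRESHOLD accessor above, the elided kernel body presumably clamps the argument of the log, computing per example i:

Y_i = -\log\big(\max(X_{i,\mathrm{label}_i},\ 10^{-20})\big)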
......@@ -12,41 +12,38 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "type_alias.h"
namespace paddle {
namespace operators {
class FullyConnectedOp : public framework::PlainNet {
class FullyConnectedOp : public NetOp {
public:
void Init() override {
AddOp(framework::OpRegistry::CreateOp("mul",
AddOp(OpRegistry::CreateOp("mul",
{
Input("X"), Input("W"),
},
{Output("before_act")},
{}));
auto b = Input("b");
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
AddOp(framework::OpRegistry::CreateOp("rowwise_add",
if (b != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")},
{Output("before_act")},
{}));
}
auto activation = GetAttr<std::string>("activation");
AddOp(framework::OpRegistry::CreateOp(
AddOp(OpRegistry::CreateOp(
activation, {Output("before_act")}, {Output("Y")}, {}));
CompleteAddOp(false);
}
};
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
......@@ -71,6 +68,4 @@ USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
REGISTER_OP(fc,
paddle::operators::FullyConnectedOp,
paddle::operators::FullyConnectedOpMaker);
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
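Because FullyConnectedOp derives from NetOp, creating an fc op yields a small sub-network; a sketch mirroring the commented-out test earlier in this diff (the Init() override above, invoked as part of creation, expands it into mul -> rowwise_add -> activation):

auto fc = paddle::framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, {});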
......@@ -13,17 +13,14 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
class MulOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs");
auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims();
......@@ -37,10 +34,10 @@ protected:
}
};
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
......@@ -52,11 +49,10 @@ The equation is: Out = X * Y
}
};
class MulOpGrad : public framework::OperatorWithKernel {
class MulOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
......@@ -66,8 +62,7 @@ protected:
} // namespace operators
} // namespace paddle
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
......@@ -13,8 +13,5 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(mul,
paddle::operators::MulKernel<paddle::platform
::GPUPlace, float>);
\ No newline at end of file
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -14,30 +14,27 @@
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
class MulKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
void Compute(const KernelContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto input0 = context.Input(0)->Get<Tensor>();
auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace());
framework::EigenMatrix<T>::From(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenMatrix<T>::From(input0).contract(
framework::EigenMatrix<T>::From(input1), dim_pair);
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1),
dim_pair);
}
};
} // namespace operators
......
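Eigen's contract with dim_pair = {IndexPair(1, 0)} sums dimension 1 of the left operand against dimension 0 of the right, i.e. a plain matrix product:

\text{Out}_{ik} = \sum_j X_{ij}\, Y_{jk}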
......@@ -13,15 +13,13 @@
limitations under the License. */
#include "paddle/operators/rowwise_add_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
class RowWiseAddOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs are needed by rowwise add");
auto dim0 = inputs[0]->dims();
auto dim1 = inputs[1]->dims();
......@@ -34,11 +32,10 @@ protected:
}
};
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector");
AddOutput("Out", "The output of row-wise add op");
......@@ -53,9 +50,6 @@ for i in xrange(X.shape[0]):
} // namespace operators
} // namespace paddle
REGISTER_OP(rowwise_add,
paddle::operators::RowWiseAddOp,
paddle::operators::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(
rowwise_add,
paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::CPUPlace, float>);
#include "paddle/framework/op_registry.h"
#include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(
rowwise_add,
paddle::operators::RowWiseAddKernel<paddle::platform ::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::GPUPlace, float>);
......@@ -13,25 +13,23 @@
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class RowWiseAddKernel : public framework::OpKernel {
class RowWiseAddKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto in0 = context.Input(0)->Get<framework::Tensor>();
auto in1 = context.Input(1)->Get<framework::Tensor>();
auto* out = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto in0 = context.Input(0)->Get<Tensor>();
auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<Tensor>();
out->mutable_data<T>(context.GetPlace());
auto input = framework::EigenMatrix<T>::From(in0);
auto bias = framework::EigenVector<T>::From(in1);
auto output = framework::EigenMatrix<T>::From(*out);
auto input = EigenMatrix<T>::From(in0);
auto bias = EigenVector<T>::From(in1);
auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
......
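The elided remainder of the kernel presumably reshapes the input into rest_size rows of bias_size columns and broadcasts the bias across them:

\text{Out}_{ij} = X_{ij} + b_j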
......@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
class SGDOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
......@@ -35,10 +32,10 @@ protected:
}
};
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
class SGDOpMaker : public OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
......@@ -55,7 +52,5 @@ param_out = param - learning_rate * grad;
} // namespace operators
} // namespace paddle
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>);
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
\ No newline at end of file
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -13,28 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel {
class SGDOpKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<framework::Tensor>();
auto grad = ctx.Input("grad")->Get<framework::Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<Tensor>();
auto grad = ctx.Input("grad")->Get<Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<Tensor>();
float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
framework::EigenVector<T>::Flatten(*param_out)
.device(*(ctx.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(param) -
lr * framework::EigenVector<T>::Flatten(grad);
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad);
}
};
......
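In standard notation, the update the kernel above performs (with \eta = learning_rate read from the op attribute) is:

\text{param\_out} = \text{param} - \eta \cdot \text{grad}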
......@@ -13,37 +13,33 @@
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
class SigmoidOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
outputs[0]->Resize(inputs[0]->dims());
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function");
}
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
class SigmoidOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad";
return "";
......@@ -53,11 +49,7 @@ protected:
} // namespace operators
} // namespace paddle
REGISTER_OP(sigmoid,
paddle::operators::SigmoidOp,
paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(
sigmoid,
paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
#include "paddle/operators/sigmoid_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL(
sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
......@@ -14,25 +14,23 @@
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
class SigmoidKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).device(
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp());
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp());
}
};
} // namespace operators
......
......@@ -12,16 +12,14 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/softmax_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
class SoftmaxOp : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
"The input of softmax op must be matrix");
......@@ -31,10 +29,9 @@ protected:
}
};
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax");
AddOutput("Y", "output of softmax");
......@@ -42,11 +39,10 @@ public:
}
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
class SoftmaxOpGrad : public OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShape(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad";
return "";
......@@ -56,9 +52,6 @@ protected:
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL(
softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
......@@ -14,23 +14,21 @@
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel {
class SoftmaxKernel : public OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input = context.Input(0)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
void Compute(const KernelContext& context) const override {
auto input = context.Input(0)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace());
auto logits = framework::EigenMatrix<T>::From(input);
auto softmax = framework::EigenMatrix<T>::From(*output);
auto logits = EigenMatrix<T>::From(input);
auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0;
const int kClassDim = 1;
......
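The elided body presumably computes the numerically stable row-wise softmax along kClassDim, shifting each row by its max before exponentiating:

\mathrm{softmax}(x)_{ij} = \frac{\exp(x_{ij} - \max_k x_{ik})}{\sum_k \exp(x_{ik} - \max_k x_{ik})}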
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using OperatorWithKernel = framework::OperatorWithKernel;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
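With these aliases (plus the file-scope namespace ops = paddle::operators; above), the operator files shrink considerably; compare one registration taken verbatim from the add_op diff:

// before
REGISTER_OP_CPU_KERNEL(
    add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
// after
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);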
......@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle.
});
ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
[]() -> std::shared_ptr<pd::NetOp> {
auto retv = std::make_shared<pd::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
.def("add_op", &pd::NetOp::AddOp)
.def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
[](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
})
.def("complete_add_op", &pd::PlainNet::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
.def("complete_add_op", &pd::NetOp::CompleteAddOp)
.def("complete_add_op",
[](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator);
......
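The second add_op overload above means nets nest: a NetOp is itself an OperatorBase, so on the C++ side the binding reduces to:

auto outer = std::make_shared<pd::NetOp>();
auto inner = std::make_shared<pd::NetOp>();
outer->AddOp(std::static_pointer_cast<pd::OperatorBase>(inner));  // a net inside a net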