提交 94b94e5b 编写于 作者: Z zchen0211

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

......@@ -39,28 +39,6 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
"sum", {{"X", {"x", "y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(grad_add_op->Inputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Outputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
auto &outputs = grad_add_op->Outputs(f::GradVarName("X"));
EXPECT_EQ(2UL, outputs.size());
auto in_output = [&outputs](const std::string &name) {
for (auto &output_name : outputs) {
if (output_name == name) return true;
}
return false;
};
EXPECT_TRUE(in_output(f::GradVarName("x")));
EXPECT_TRUE(in_output(f::GradVarName("y")));
}
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
......
......@@ -103,12 +103,16 @@ set(DEPS_OPS
recurrent_op
cond_op
cross_entropy_op
softmax_with_cross_entropy_op)
softmax_with_cross_entropy_op
sum_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -27,6 +27,8 @@ class SGDOp : public framework::OperatorWithKernel {
"Input(param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"),
"Input(grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
"Input(learning_rate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
"Output(param_out) of SGDOp should not be null.");
......@@ -42,9 +44,9 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("learning_rate", "learning rate of sgd");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
AddAttr<float>("learning_rate", "learning rate of sgd");
AddComment(R"DOC(
Simplest sgd algorithm.
......
......@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.Attr<float>("learning_rate");
float lr = *ctx.Input<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
......@@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input.
}
};
class SumGradOp : public framework::OperatorWithKernel {
class SumGradOp : public NetOp {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: NetOp(type, inputs, outputs, attrs) {
auto& x_grad_names = Outputs(framework::GradVarName("X"));
auto out_grad_name = this->Input(framework::GradVarName("Out"));
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
size_t x_length = x_grad_names.size();
std::vector<framework::DDim> x_grad_dims;
x_grad_dims.reserve(x_length);
for (size_t i = 0; i < x_length; ++i) {
x_grad_dims.push_back(out_grad_dims);
framework::AttributeMap grad_attrs;
grad_attrs["scale"] = 1.0f;
for (auto& x_grad_name : x_grad_names) {
AppendOp(framework::OpRegistry::CreateOp(
"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
grad_attrs));
}
ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
CompleteAddOp(false);
}
};
......@@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::CPUPlace, float>);
......@@ -14,5 +14,3 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::GPUPlace, float>);
......@@ -42,24 +42,5 @@ class SumKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
for (auto out : outs) {
out->mutable_data<T>(context.GetPlace());
}
auto place = context.GetEigenDevice<Place>();
auto in = EigenVector<T>::Flatten(*input);
for (auto out : outs) {
auto result = EigenVector<T>::Flatten(*out);
result.device(place) = in;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -143,6 +143,13 @@ All parameter, weight, gradient are variables in Paddle.
.def("set_int",
[](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
.def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
.def("is_float", [](const Variable &var) { return var.IsType<float>(); })
.def("set_float",
[](Variable &var, float val) -> void {
*var.GetMutable<float>() = val;
})
.def("get_float",
[](const Variable &var) -> float { return var.Get<float>(); })
.def("get_tensor",
[](Variable &self) -> LoDTensor * {
return self.GetMutable<LoDTensor>();
......
......@@ -46,12 +46,17 @@ def create_op(scope, op_type, inputs, outputs, attrs):
def set_input(scope, op, inputs, place):
def __set_input__(var_name, var):
tensor = scope.find_var(var_name).get_tensor()
if isinstance(var, tuple):
tensor.set_lod(var[1])
var = var[0]
tensor.set_dims(var.shape)
tensor.set(var, place)
if isinstance(var, tuple) or isinstance(var, np.ndarray):
tensor = scope.find_var(var_name).get_tensor()
if isinstance(var, tuple):
tensor.set_lod(var[1])
var = var[0]
tensor.set_dims(var.shape)
tensor.set(var, place)
elif isinstance(var, float):
scope.find_var(var_name).set_float(var)
elif isinstance(var, int):
scope.find_var(var_name).set_int(var)
for in_name, in_dup in Operator.get_op_inputs(op.type()):
if in_name in inputs:
......
......@@ -10,8 +10,7 @@ class TestSGDOp(OpTest):
g = np.random.random((102, 105)).astype("float32")
lr = 0.1
self.inputs = {'param': w, 'grad': g}
self.attrs = {'learning_rate': lr}
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
self.outputs = {'param_out': w - lr * g}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册