提交 fddaf0c4 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4562 from reyoung/feature/simplify_sum_op

Feature/simplify sum op
......@@ -39,28 +39,6 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
// Creates a "sum" op over inputs {x, y}, asks the registry for its gradient
// op, and checks the gradient op's slot counts and variable names.
TEST(GradOpBuilder, AddTwo) {
  // Forward op: slot X holds {"x", "y"}, slot Out holds {"out"}, no attributes.
  std::shared_ptr<f::OperatorBase> forward_op(f::OpRegistry::CreateOp(
      "sum", {{"X", {"x", "y"}}}, {{"Out", {"out"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_add_op =
      f::OpRegistry::CreateGradOp(*forward_op);

  // The gradient op is expected to expose exactly one input slot and one
  // output slot, with the Out-gradient slot bound to the gradient of "out".
  EXPECT_EQ(grad_add_op->Inputs().size(), 1UL);
  EXPECT_EQ(grad_add_op->Outputs().size(), 1UL);
  EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));

  // The X-gradient slot must carry two names; their order is not asserted.
  auto &outputs = grad_add_op->Outputs(f::GradVarName("X"));
  EXPECT_EQ(2UL, outputs.size());

  // Linear membership check over the X-gradient output names.
  auto in_output = [&outputs](const std::string &wanted) {
    bool found = false;
    for (const auto &candidate : outputs) {
      if (candidate == wanted) {
        found = true;
        break;
      }
    }
    return found;
  };
  EXPECT_TRUE(in_output(f::GradVarName("x")));
  EXPECT_TRUE(in_output(f::GradVarName("y")));
}
// Test-only op registrations. Both use f::NOP for the second and fifth
// arguments; NOTE(review): by analogy with other REGISTER_OP uses these are
// presumably (op type, op class, proto maker, grad op type, grad op class) --
// confirm against the macro definition.
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
......
......@@ -103,12 +103,16 @@ set(DEPS_OPS
recurrent_op
cond_op
cross_entropy_op
softmax_with_cross_entropy_op)
softmax_with_cross_entropy_op
sum_op)
# Explicit op_library() declarations for the ops named in DEPS_OPS, each with
# its own SRCS and/or DEPS list. op_library is defined outside this excerpt.
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
# sum_op is declared with a dependency on net_op.
op_library(sum_op DEPS net_op)
# Remove the explicitly declared ops from GENERAL_OPS before the loop below
# iterates over what is left.
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
......@@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input.
}
};
// Gradient operator for "sum".
// NOTE(review): this span contains two class heads (one deriving from
// framework::OperatorWithKernel, one from NetOp) and unbalanced braces; it
// looks like the removed and added lines of a diff hunk shown interleaved
// without +/- markers, so it is not compilable as written. The comments below
// describe each fragment separately.
class SumGradOp : public framework::OperatorWithKernel {
class SumGradOp : public NetOp {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// NetOp-based constructor: forwards type/inputs/outputs/attrs to NetOp, then
// reads the X-gradient output names and the Out-gradient input name.
SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: NetOp(type, inputs, outputs, attrs) {
auto& x_grad_names = Outputs(framework::GradVarName("X"));
auto out_grad_name = this->Input(framework::GradVarName("Out"));
protected:
// InferShape fragment: collects one copy of the Out-gradient dims per
// X-gradient output name.
void InferShape(framework::InferShapeContextBase* ctx) const override {
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
size_t x_length = x_grad_names.size();
std::vector<framework::DDim> x_grad_dims;
x_grad_dims.reserve(x_length);
for (size_t i = 0; i < x_length; ++i) {
x_grad_dims.push_back(out_grad_dims);
// Constructor fragment: for every X-gradient name, append a "scale" op with
// attribute scale = 1.0f whose input X is the Out gradient and whose output
// Out is that X gradient.
framework::AttributeMap grad_attrs;
grad_attrs["scale"] = 1.0f;
for (auto& x_grad_name : x_grad_names) {
AppendOp(framework::OpRegistry::CreateOp(
"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
grad_attrs));
}
// InferShape fragment: assigns the collected dims to all X-gradient outputs.
ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
// Constructor fragment: finishes building the net. NOTE(review): the meaning
// of the `false` argument is not visible here -- confirm in net_op.h.
CompleteAddOp(false);
}
};
......@@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
// Registers op "sum" (ops::SumOp, described by ops::SumOpMaker) together with
// "sum_grad" implemented by ops::SumGradOp.
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
// CPU kernel for "sum", instantiated for float.
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
// CPU kernel for "sum_grad", instantiated for float.
REGISTER_OP_CPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::CPUPlace, float>);
......@@ -14,5 +14,3 @@ limitations under the License. */
namespace ops = paddle::operators;
// GPU kernel for "sum", instantiated for float.
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
// GPU kernel for "sum_grad", instantiated for float.
REGISTER_OP_GPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::GPUPlace, float>);
......@@ -42,24 +42,5 @@ class SumKernel : public framework::OpKernel<T> {
}
};
// Kernel for the gradient of "sum": writes the flattened gradient of Out into
// the flattened view of every gradient output of X.
//
// Template parameters:
//   Place - device tag passed to ExecutionContext::GetEigenDevice.
//   T     - element type used for mutable_data and the Eigen vector views.
template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto x_grads = context.MultiOutput<Tensor>(framework::GradVarName("X"));

    // First pass: call mutable_data<T>() on every output for the context's
    // place before any assignment is issued.
    for (auto x_grad : x_grads) {
      x_grad->mutable_data<T>(context.GetPlace());
    }

    // Second pass: assign the flattened Out gradient to each flattened X
    // gradient through the Eigen device obtained from the context.
    auto device = context.GetEigenDevice<Place>();
    auto flat_out_grad = EigenVector<T>::Flatten(*out_grad);
    for (auto x_grad : x_grads) {
      auto flat_x_grad = EigenVector<T>::Flatten(*x_grad);
      flat_x_grad.device(device) = flat_out_grad;
    }
  }
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册