提交 adec0d30 编写于 作者: Y Yu Yang

Simplify SumOp Kernel

上级 9ff1fd41
...@@ -103,12 +103,16 @@ set(DEPS_OPS ...@@ -103,12 +103,16 @@ set(DEPS_OPS
recurrent_op recurrent_op
cond_op cond_op
cross_entropy_op cross_entropy_op
softmax_with_cross_entropy_op) softmax_with_cross_entropy_op
sum_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op) DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/operators/sum_op.h" #include "paddle/operators/sum_op.h"
#include <vector> #include <vector>
#include "paddle/operators/net_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input. ...@@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input.
} }
}; };
class SumGradOp : public framework::OperatorWithKernel { class SumGradOp : public NetOp {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: NetOp(type, inputs, outputs, attrs) {
auto& x_grad_names = Outputs(framework::GradVarName("X"));
auto out_grad_name = this->Input(framework::GradVarName("Out"));
protected: framework::AttributeMap grad_attrs;
void InferShape(framework::InferShapeContextBase* ctx) const override { grad_attrs["scale"] = 1.0f;
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); for (auto& x_grad_name : x_grad_names) {
auto x_grad_names = ctx->Outputs(framework::GradVarName("X")); AppendOp(framework::OpRegistry::CreateOp(
size_t x_length = x_grad_names.size(); "scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
std::vector<framework::DDim> x_grad_dims; grad_attrs));
x_grad_dims.reserve(x_length);
for (size_t i = 0; i < x_length; ++i) {
x_grad_dims.push_back(out_grad_dims);
} }
ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims); CompleteAddOp(false);
} }
}; };
...@@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel { ...@@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp); REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::CPUPlace, float>);
...@@ -13,6 +13,4 @@ limitations under the License. */ ...@@ -13,6 +13,4 @@ limitations under the License. */
#include "paddle/operators/sum_op.h" #include "paddle/operators/sum_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sum_grad, \ No newline at end of file
ops::SumGradKernel<paddle::platform::GPUPlace, float>);
...@@ -42,24 +42,5 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -42,24 +42,5 @@ class SumKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
for (auto out : outs) {
out->mutable_data<T>(context.GetPlace());
}
auto place = context.GetEigenDevice<Place>();
auto in = EigenVector<T>::Flatten(*input);
for (auto out : outs) {
auto result = EigenVector<T>::Flatten(*out);
result.device(place) = in;
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册