Commit adec0d30 authored by Yu Yang

Simplify SumOp Kernel

Parent 9ff1fd41
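The change relies on a simple fact about the gradient of a sum: for

    y = \sum_{i=1}^{n} x_i

each partial derivative \partial y / \partial x_i equals 1, so the chain rule gives

    \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x_i} = \frac{\partial L}{\partial y}

for every input. The backward pass is therefore just a fan-out copy of the output gradient, so sum_grad needs no kernel of its own: it can be expressed as a NetOp that appends one existing scale op (scale = 1.0, an element-wise copy) per input gradient, which is exactly what this commit does.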
paddle/operators/CMakeLists.txt
@@ -103,12 +103,16 @@ set(DEPS_OPS
     recurrent_op
     cond_op
     cross_entropy_op
-    softmax_with_cross_entropy_op)
+    softmax_with_cross_entropy_op
+    sum_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor net_op)
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
 op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
+op_library(sum_op DEPS net_op)
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
paddle/operators/sum_op.cc
@@ -11,6 +11,7 @@ limitations under the License. */

 #include "paddle/operators/sum_op.h"
 #include <vector>
+#include "paddle/operators/net_op.h"

 namespace paddle {
 namespace operators {
@@ -57,21 +58,23 @@ or not. But the output only shares the LoD with the first input.
   }
 };

-class SumGradOp : public framework::OperatorWithKernel {
+class SumGradOp : public NetOp {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+  SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
+            const framework::VariableNameMap& outputs,
+            const framework::AttributeMap& attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    auto& x_grad_names = Outputs(framework::GradVarName("X"));
+    auto out_grad_name = this->Input(framework::GradVarName("Out"));

- protected:
-  void InferShape(framework::InferShapeContextBase* ctx) const override {
-    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    auto x_grad_names = ctx->Outputs(framework::GradVarName("X"));
-    size_t x_length = x_grad_names.size();
-    std::vector<framework::DDim> x_grad_dims;
-    x_grad_dims.reserve(x_length);
-    for (size_t i = 0; i < x_length; ++i) {
-      x_grad_dims.push_back(out_grad_dims);
+    framework::AttributeMap grad_attrs;
+    grad_attrs["scale"] = 1.0f;
+    for (auto& x_grad_name : x_grad_names) {
+      AppendOp(framework::OpRegistry::CreateOp(
+          "scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
+          grad_attrs));
     }
-    ctx->SetOutputsDim(framework::GradVarName("X"), x_grad_dims);
+    CompleteAddOp(false);
   }
 };
@@ -81,5 +84,3 @@ class SumGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
 REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(sum_grad,
-                       ops::SumGradKernel<paddle::platform::CPUPlace, float>);
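For illustration, with two inputs the constructor loop above unrolls to the following sketch; it reuses only calls that appear in the diff, and the gradient variable names x0@GRAD / x1@GRAD are hypothetical placeholders for whatever GradVarName produces:

// Unrolled form of SumGradOp's constructor loop for two inputs.
// "scale" with scale = 1.0f is an element-wise copy, so each input
// gradient receives a copy of the output gradient Out@GRAD.
framework::AttributeMap grad_attrs;
grad_attrs["scale"] = 1.0f;
AppendOp(framework::OpRegistry::CreateOp(
    "scale", {{"X", {"Out@GRAD"}}}, {{"Out", {"x0@GRAD"}}}, grad_attrs));
AppendOp(framework::OpRegistry::CreateOp(
    "scale", {{"X", {"Out@GRAD"}}}, {{"Out", {"x1@GRAD"}}}, grad_attrs));
CompleteAddOp(false);  // as in the constructor above

Because the sub-ops are ordinary registered scale ops, sum_grad itself no longer needs a CPU or GPU kernel registration, which is why the sum_grad entries below (and in sum_op.cu) are deleted.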
paddle/operators/sum_op.cu
@@ -13,6 +13,4 @@ limitations under the License. */
 #include "paddle/operators/sum_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(sum_grad,
-                       ops::SumGradKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
paddle/operators/sum_op.h
@@ -42,24 +42,5 @@ class SumKernel : public framework::OpKernel<T> {
   }
 };

-template <typename Place, typename T>
-class SumGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
-    for (auto out : outs) {
-      out->mutable_data<T>(context.GetPlace());
-    }
-
-    auto place = context.GetEigenDevice<Place>();
-    auto in = EigenVector<T>::Flatten(*input);
-    for (auto out : outs) {
-      auto result = EigenVector<T>::Flatten(*out);
-      result.device(place) = in;
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
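The device-side copy that SumGradKernel used to perform now happens inside the scale op's kernel instead. A minimal sketch of such a kernel, written in the style of the Eigen-based SumKernel above (the class name ScaleKernelSketch and the attribute accessor are assumptions, not the exact Paddle source):

// Sketch only: an Eigen-based scale kernel in the style of the kernels
// above. With scale = 1.0f, as SumGradOp sets it, this reduces to a
// device-side copy of the input into the output.
template <typename Place, typename T>
class ScaleKernelSketch : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    auto scale = static_cast<T>(context.Attr<float>("scale"));  // assumed accessor
    auto eigen_in = EigenVector<T>::Flatten(*in);
    auto eigen_out = EigenVector<T>::Flatten(*out);
    eigen_out.device(context.GetEigenDevice<Place>()) = scale * eigen_in;
  }
};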