提交 29065d43 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #4503 from luotao1/remove_rowwise_add_op

remove rowwise_add_op
......@@ -100,7 +100,7 @@ class FCOp : public NetOp {
add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp(
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
"elementwise_add", {{"X", {sum_out}}, {"Y", {Input("B")}}},
{{"Out", {add_out}}}, {}));
} else {
if (Output("AddOut") != framework::kEmptyVarName) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rowwise_add_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class RowwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("b"),
"Input(b) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RowwiseAddOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`.");
int num_col_dims = x_dims.size() - b_dims.size();
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
PADDLE_ENFORCE_EQ(ctx->Outputs("Out").size(), 1,
"The output size must be 1");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowwiseAddOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector");
AddOutput("Out", "The output of row-wise add op");
AddComment(R"DOC(Row-wise Add operator
for i in xrange(X.shape[0]):
Out = X[i] + b
)DOC");
}
};
class RowwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X should not be null");
PADDLE_ENFORCE(ctx->HasInput("b"), "b should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`.");
int64_t num_col_dims = x_dims.size() - b_dims.size();
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
auto x_grad_name = framework::GradVarName("X");
auto b_grad_name = framework::GradVarName("b");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(b_grad_name)) {
ctx->SetOutputDim(b_grad_name, b_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class RowwiseAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
int num_col_dims = context.Input<Tensor>("X")->dims().size() -
context.Input<Tensor>("b")->dims().size();
auto input =
EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_col_dims);
auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
auto output = EigenMatrix<T>::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
}
};
template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
int num_col_dims = context.Input<Tensor>("X")->dims().size() -
context.Input<Tensor>("b")->dims().size();
auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice<Place>();
if (dx) {
dx->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
db->mutable_data<T>(context.GetPlace());
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestRowwiseAddOp(OpTest):
def setUp(self):
self.op_type = "rowwise_add"
self.inputs = {
'X': np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
'b': np.random.uniform(0.1, 1, [10]).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'b'], 'Out')
def test_check_grad_ingore_b(self):
self.check_grad(['X'], 'Out', no_grad_set=set('b'))
def test_check_grad_ingore_x(self):
self.check_grad(['b'], 'Out', no_grad_set=set('X'))
class TestRowwiseAddOp2(OpTest):
def setUp(self):
self.op_type = "rowwise_add"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
'b': np.random.uniform(0.1, 1, [2, 5]).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'b'], 'Out')
def test_check_grad_ignore_b(self):
self.check_grad(['X'], 'Out', no_grad_set=set('b'))
def test_check_grad_ignore_x(self):
self.check_grad(['b'], 'Out', no_grad_set=set('X'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册