Unverified commit dfec6762 authored by Hongyu Liu, committed by GitHub

expand op support tensor attribute (#17773)

* expand support tensor attribute; test=develop

* fix bug; test=develop

* fix unit test bug; test=develop

* fix copy bug; test=develop

* refine expand_times default value; test=develop
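
With this change, entries of expand_times may be either Python ints or 1-element int32 tensors, so the tile factors can be computed at runtime. A minimal sketch of the new usage (variable names are illustrative, not part of this patch):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[10], dtype='float32')
    # the second tile factor is a runtime tensor instead of a compile-time int
    times = fluid.layers.fill_constant(shape=[1], dtype='int32', value=2)
    out = fluid.layers.expand(x, expand_times=[1, times])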
Parent 3b70f870
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/expand_op.h"
 #include <memory>
+#include <string>
 #include <vector>

 namespace paddle {
@@ -30,9 +31,12 @@ class ExpandOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
-    std::vector<int> expand_times =
-        ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto x_dims = ctx->GetInputDim("X");
+    std::vector<int> expand_times(x_dims.size(), -1);
+
+    if (!ctx->HasInputs("expand_times_tensor")) {
+      expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
+    }

     PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
                       "The number of Attr(expand_times)'s value must be equal "
@@ -42,15 +46,11 @@ class ExpandOp : public framework::OperatorWithKernel {
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
-      PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of Attr(expand_times) should not be "
-                        "less than 1.");
-      out_shape[i] = x_dims[i] * expand_times[i];
-    }
-
-    // set the first dim to -1 in compile time
-    if (!ctx->IsRuntime() && x_dims[0] < 0) {
-      out_shape[0] = x_dims[0];
+      if (x_dims[i] == -1 || expand_times[i] == -1) {
+        out_shape[i] = -1;
+      } else {
+        out_shape[i] = x_dims[i] * expand_times[i];
+      }
     }

     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
@@ -58,6 +58,23 @@ class ExpandOp : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", "Out");
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "expand_times_tensor") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
 };
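
Note: since the tile factors may now be unknown until runtime, InferShape above marks every affected dimension as -1 at compile time, and the kernel resizes Out once the factors are read. A hedged illustration of the resulting static shapes (assuming the layer changes later in this patch):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[10], dtype='float32')  # static shape [-1, 10]
    t = fluid.layers.fill_constant(shape=[1], dtype='int32', value=3)
    out = fluid.layers.expand(x, expand_times=[1, t])
    print(out.shape)  # expected [-1, -1]: tensor-provided factors are unknown statically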
 class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -66,6 +83,9 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X",
              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
              "X is the input to be expanded.");
+    AddInput("expand_times_tensor", "(Tensor, Tensor<int>), expand times for X")
+        .AsDuplicable()
+        .AsDispensable();
     AddOutput("Out",
               "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
               "The rank of Output(Out) have the same with Input(X). "
@@ -73,7 +93,8 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
              "to size of the corresponding dimension of Input(X) multiplying "
              "the corresponding value given by Attr(expand_times).");
     AddAttr<std::vector<int>>("expand_times",
-                              "Expand times number for each dimension.");
+                              "Expand times number for each dimension.")
+        .SetDefault({});
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
 number for each dimension by providing attribute 'expand_times'. The rank of X
@@ -113,6 +134,7 @@ class ExpandGradOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     std::vector<int> expand_times =
         ctx->Attrs().Get<std::vector<int>>("expand_times");
+
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

     size_t start_pos = 0u;
@@ -137,6 +159,23 @@ class ExpandGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "expand_times_tensor") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
 };
 class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
@@ -150,6 +189,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
     op->SetInput("X", Input("X"));
     op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetInput("expand_times_tensor", Input("expand_times_tensor"));
     op->SetAttrMap(Attrs());
     return op;
   }
...
@@ -48,6 +48,29 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+inline std::vector<int> get_expand_times(
+    const framework::ExecutionContext& ctx) {
+  auto list_expand_times_tensor =
+      ctx.MultiInput<framework::Tensor>("expand_times_tensor");
+  if (list_expand_times_tensor.size() > 0) {
+    // get expand times from the input tensors
+    std::vector<int> vec_expand_times;
+    for (size_t i = 0; i < list_expand_times_tensor.size(); ++i) {
+      auto tensor = list_expand_times_tensor[i];
+      if (platform::is_gpu_place(tensor->place())) {
+        framework::Tensor temp;
+        TensorCopySync(*tensor, platform::CPUPlace(), &temp);
+        vec_expand_times.push_back(*temp.data<int32_t>());
+      } else {
+        vec_expand_times.push_back(*tensor->data<int32_t>());
+      }
+    }
+
+    return vec_expand_times;
+  } else {
+    return ctx.Attr<std::vector<int>>("expand_times");
+  }
+}
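
get_expand_times expects each expand_times_tensor entry to be a tensor holding a single int32 (one entry per dimension of X) and syncs GPU tensors to the CPU before dereferencing them. A sketch of the expected feed layout, mirroring the unit tests further below:

    import numpy as np

    # one 1-element int32 tensor per dimension of X; names 'x1'/'x2' follow the tests
    expand_times_tensor = [('x1', np.array([2], dtype='int32')),
                           ('x2', np.array([3], dtype='int32'))]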
 using Tensor = framework::Tensor;

 template <typename T, int MajorType = Eigen::RowMajor,
@@ -74,12 +97,21 @@ class ExpandKernel : public framework::OpKernel<T> {
   template <int Rank>
   void Expand(const framework::ExecutionContext& context) const {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
+
+    auto in_dims = in0->dims();
+    auto expand_times = get_expand_times(context);
     auto* out0 = context.Output<Tensor>("Out");
     Eigen::DSizes<int, Rank> bcast_dims;
     for (size_t i = 0; i < expand_times.size(); ++i) {
       bcast_dims[i] = expand_times[i];
     }
+
+    framework::DDim out_dims(in_dims);
+    for (size_t i = 0; i < expand_times.size(); ++i) {
+      out_dims[i] *= expand_times[i];
+    }
+
+    out0->Resize(out_dims);
     auto x = EigenTensor<T, Rank>::From(*in0);
     out0->mutable_data<T>(context.GetPlace());
     auto y = EigenTensor<T, Rank>::From(*out0);
@@ -94,7 +126,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
+    // auto& expand_times = context.Attr<std::vector<int>>("expand_times");
+    auto expand_times = get_expand_times(context);
     auto x_dims = in0->dims();
     // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
     //    if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
...
@@ -28,7 +28,7 @@ from ..framework import Variable, OpProtoHolder, in_dygraph_mode
 from ..dygraph import base
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
-from .tensor import concat, assign
+from .tensor import concat, assign, fill_constant
 from . import utils
 from .. import unique_name
 from functools import reduce
@@ -9329,11 +9329,38 @@ def expand(x, expand_times, name=None):
     helper = LayerHelper('expand', input=x, **locals())
     dtype = helper.input_dtype(input_param_name='x')
     out = helper.create_variable_for_type_inference(dtype)
+
+    # check whether expand_times contains a tensor
+    if in_dygraph_mode():
+        inputs = {'X': x}
+        attrs = {'expand_times': expand_times}
+    else:
+
+        def contain_tensor(expand_times):
+            for ele in expand_times:
+                if isinstance(ele, Variable):
+                    return True
+            return False
+
+        if contain_tensor(expand_times):
+            new_expand_times = []
+            for ele in expand_times:
+                if isinstance(ele, Variable):
+                    new_expand_times.append(ele)
+                else:
+                    assert (isinstance(ele, int))
+                    temp_out = helper.create_variable_for_type_inference(dtype)
+                    fill_constant(
+                        [1], 'int32', ele, force_cpu=True, out=temp_out)
+                    new_expand_times.append(temp_out)
+
+            inputs = {'X': x, 'expand_times_tensor': new_expand_times}
+            attrs = {}
+        else:
+            inputs = {'X': x}
+            attrs = {'expand_times': expand_times}
+
     helper.append_op(
-        type='expand',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'expand_times': expand_times})
+        type='expand', inputs=inputs, outputs={'Out': out}, attrs=attrs)
+
     return out
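
In graph mode the wrapper normalizes a mixed list: Variable entries are passed through, while plain ints are materialized as 1-element int32 tensors via fill_constant(..., force_cpu=True), so the op receives a uniform expand_times_tensor input and an empty expand_times attribute. A rough sketch of an equivalent call (names are illustrative):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[12, 14], dtype='float32', append_batch_size=False)
    t = fluid.layers.fill_constant(shape=[1], dtype='int32', value=3)
    # appends an expand op with inputs={'X': x, 'expand_times_tensor': [...]}
    # and attrs={}; the int 2 becomes a force_cpu fill_constant tensor internally
    out = fluid.layers.expand(x, expand_times=[2, t])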
...
@@ -34,6 +34,24 @@ class TestExpandOpRank1(OpTest):
         self.check_grad(['X'], 'Out')


+class TestExpandOpRank1_tensor_attr(OpTest):
+    def setUp(self):
+        self.op_type = "expand"
+        self.inputs = {
+            'X': np.random.random(12).astype("float32"),
+            'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2)]
+        }
+        self.attrs = {}
+        output = np.tile(self.inputs['X'], 2)
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('x1'))
+
+
 class TestExpandOpRank2_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
@@ -49,6 +67,25 @@ class TestExpandOpRank2_Corner(OpTest):
         self.check_grad(['X'], 'Out')


+class TestExpandOpRank2_Corner_tensor_attr(OpTest):
+    def setUp(self):
+        self.op_type = "expand"
+        self.inputs = {
+            'X': np.random.random((12, 14)).astype("float32"),
+            'expand_times_tensor': [('x1', np.ones((1)).astype('int32')),
+                                    ('x2', np.ones((1)).astype('int32'))]
+        }
+        self.attrs = {}
+        output = np.tile(self.inputs['X'], (1, 1))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 class TestExpandOpRank2(OpTest):
     def setUp(self):
         self.op_type = "expand"
@@ -64,6 +101,25 @@ class TestExpandOpRank2(OpTest):
         self.check_grad(['X'], 'Out')


+class TestExpandOpRank2_attr_tensor(OpTest):
+    def setUp(self):
+        self.op_type = "expand"
+        self.inputs = {
+            'X': np.random.random((12, 14)).astype("float32"),
+            'expand_times_tensor': [('x1', np.ones((1)).astype('int32') * 2),
+                                    ('x2', np.ones((1)).astype('int32') * 3)]
+        }
+        self.attrs = {}
+        output = np.tile(self.inputs['X'], (2, 3))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 class TestExpandOpRank3_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
...
@@ -104,6 +104,7 @@ class TestInferShape(unittest.TestCase):
         sum_op_desc = block.append_op()
         sum_op_desc.set_type("expand")
         sum_op_desc.set_input("X", ["x"])
+        sum_op_desc.set_input('expand_times_tensor', [])
         sum_op_desc.set_output("Out", ["out"])
         sum_op_desc._set_attr('expand_times', expand_times)
...