Unverified commit 39e129a1 authored by Xin Pan, committed by GitHub

Merge pull request #13794 from chengduoZH/set_right_shape_of_seleceted_rows_release

Set the output shape of selected rows in the right way.
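Context for the change: operators used to propagate output shape with ctx->SetOutputDim(out, ctx->GetInputDim(in)), which carries only the dense dims. For a SelectedRows variable the shape is really three pieces of metadata (the value tensor's dims, the selected row indices, and the logical height), so outputs ended up with the right dims but empty rows and zero height. This PR introduces ShareDim, which copies whatever shape metadata the variable type actually has. Below is a standalone sketch of the difference with toy types (DenseTensor, SelectedRowsVar, and both free functions are illustrative stand-ins, not Paddle's API):

// Standalone sketch (toy types, not Paddle's API) contrasting the old and
// new shape propagation for a sparse SelectedRows variable.
#include <cassert>
#include <cstdint>
#include <vector>

struct DenseTensor {
  std::vector<int64_t> dims;
};

// A SelectedRows is a dense value tensor plus the indices of the rows it
// holds and the logical height of the full sparse tensor.
struct SelectedRowsVar {
  DenseTensor value;
  std::vector<int64_t> rows;
  int64_t height = 0;
};

// Old behaviour: SetOutputDim(out, GetDim(in)) only copies the value dims.
void SetOutputDimOnly(const SelectedRowsVar& in, SelectedRowsVar* out) {
  out->value.dims = in.value.dims;  // rows stay empty, height stays 0
}

// New behaviour: ShareDim copies every piece of shape metadata.
void ShareDim(const SelectedRowsVar& in, SelectedRowsVar* out) {
  out->value.dims = in.value.dims;
  out->rows = in.rows;
  out->height = in.height;
}

int main() {
  SelectedRowsVar in{{{7, 12}}, {0, 1, 2, 3, 4, 5, 6}, 100};
  SelectedRowsVar out_old, out_new;
  SetOutputDimOnly(in, &out_old);
  ShareDim(in, &out_new);
  assert(out_old.rows.empty() && out_old.height == 0);  // incomplete output
  assert(out_new.rows == in.rows && out_new.height == 100);  // complete
  return 0;
}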
...@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   const std::vector<std::string> &Outputs(
       const std::string &name) const override;
 
+  void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
+                size_t j = 0) override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    const std::string &input_n = Inputs(in)[i];
+    const std::string &output_n = Outputs(out)[j];
+
+    PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
+                   in, i);
+    PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
+
+    auto *in_var = block_.FindVarRecursive(input_n);
+    auto *out_var = block_.FindVarRecursive(output_n);
+
+    PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
+                   "The type of %s and %s is not the same.", input_n, output_n);
+
+    SetDim(output_n, GetDim(input_n));
+  }
+
   void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                 size_t j = 0) const override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());
......
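The compile-time override above bottoms out in SetDim(output_n, GetDim(input_n)): at this point only the program description exists, and a VarDesc records just static shape and type, while the row indices of a SelectedRows are produced at runtime. A standalone sketch of that limitation (the VarDesc struct and CompileTimeShareDim are toy stand-ins, not Paddle's API):

// Standalone sketch (toy types, not Paddle's) of the compile-time limitation:
// a VarDesc in the program description carries only static dims.
#include <cassert>
#include <cstdint>
#include <vector>

struct VarDesc {
  std::vector<int64_t> dims;  // static shape: all the program desc records
};

// Mirrors the effective behaviour of CompileTimeInferShapeContext::ShareDim.
void CompileTimeShareDim(const VarDesc& in, VarDesc* out) {
  out->dims = in.dims;  // rows/height do not exist yet; see the runtime path
}

int main() {
  VarDesc x{{7, 12}};
  VarDesc out;
  CompileTimeShareDim(x, &out);
  assert(out.dims == x.dims);
  return 0;
}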
...@@ -542,13 +542,45 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return op_.Outputs(name);
   }
 
+  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
+                size_t j = 0) override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    const std::string& input_n = Inputs(in)[i];
+    const std::string& output_n = Outputs(out)[j];
+
+    Variable* in_var = scope_.FindVar(input_n);
+    Variable* out_var = scope_.FindVar(output_n);
+    PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
+                   "The type of %s and %s is not the same.", input_n, output_n);
+
+    if (in_var->IsType<framework::SelectedRows>()) {
+      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
+      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
+      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
+      out_sele_rows->set_rows(in_sele_rows.rows());
+      out_sele_rows->set_height(in_sele_rows.height());
+    } else if (in_var->IsType<framework::LoDTensor>()) {
+      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
+      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
+      out_lod_tensor->Resize(in_lod_tensor.dims());
+    } else {
+      PADDLE_THROW(
+          "Currently, the input type of ShareDim can only be LoDTensor "
+          "or SelectedRows.");
+    }
+  }
+
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const override {
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
-    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
+    const std::vector<std::string>& inputs = Inputs(in);
+    const std::vector<std::string>& outputs = Outputs(out);
+    PADDLE_ENFORCE_LT(i, inputs.size());
+    PADDLE_ENFORCE_LT(j, outputs.size());
+    Variable* in_var = scope_.FindVar(inputs.at(i));
     if (!in_var->IsType<LoDTensor>()) return;
+    Variable* out_var = scope_.FindVar(outputs.at(j));
     PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                    "The %d-th output of Output(%s) must be LoDTensor.", j, out);
     auto in_tensor = in_var->Get<LoDTensor>();
......
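At runtime the concrete Variable is available, so the override above can branch on the variable type: for SelectedRows it copies value dims, rows, and height; for LoDTensor it just resizes; a type mismatch between input and output is rejected up front. A standalone sketch of the same dispatch (the std::variant-based Variable and toy structs are illustrative, not Paddle's API):

// Standalone sketch (toy std::variant types, not Paddle's API) of the
// runtime ShareDim dispatch shown above.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <variant>
#include <vector>

struct LoDTensor {
  std::vector<int64_t> dims;
};
struct SelectedRowsVar {
  std::vector<int64_t> value_dims, rows;
  int64_t height = 0;
};
using Variable = std::variant<LoDTensor, SelectedRowsVar>;

void ShareDim(const Variable& in, Variable* out) {
  if (in.index() != out->index()) {
    throw std::runtime_error("The types of input and output are not the same.");
  }
  if (const auto* in_sr = std::get_if<SelectedRowsVar>(&in)) {
    auto& out_sr = std::get<SelectedRowsVar>(*out);
    out_sr.value_dims = in_sr->value_dims;  // mutable_value()->Resize(...)
    out_sr.rows = in_sr->rows;              // set_rows(...)
    out_sr.height = in_sr->height;          // set_height(...)
  } else {
    std::get<LoDTensor>(*out).dims = std::get<LoDTensor>(in).dims;  // Resize
  }
}

int main() {
  Variable in = SelectedRowsVar{{7, 12}, {0, 1, 2}, 100};
  Variable out = SelectedRowsVar{};
  ShareDim(in, &out);
  std::cout << std::get<SelectedRowsVar>(out).height << "\n";  // prints 100
  return 0;
}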
...@@ -58,6 +58,9 @@ class InferShapeContext {
   void ShareLoDs(const std::string &in, const std::string &out) const;
 
+  virtual void ShareDim(const std::string &in, const std::string &out,
+                        size_t i = 0, size_t j = 0) = 0;
+
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;
......
...@@ -79,7 +79,7 @@ class ActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
...@@ -90,12 +90,26 @@ class ActivationOp : public framework::OperatorWithKernel {
   }
 };
 
+class ActivationOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto x_name = op_desc.Input("X")[0];
+    auto out_name = op_desc.Output("Out")[0];
+    auto& x = block->FindRecursiveOrCreateVar(x_name);
+    auto& out = block->FindRecursiveOrCreateVar(out_name);
+    out.SetType(x.GetType());
+    out.SetDataType(x.GetDataType());
+  }
+};
+
 class ActivationOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
+    ctx->ShareDim("Out", framework::GradVarName("X"));
+    ctx->ShareLoD("Out", framework::GradVarName("X"));
   }
 
  protected:
...@@ -524,12 +538,14 @@ namespace ops = paddle::operators;
 #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)               \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,        \
                     ::paddle::operators::OP_NAME##OpMaker,                 \
+                    ::paddle::operators::ActivationOpInferVarType,         \
                     ::paddle::operators::OP_NAME##GradMaker);              \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                     ::paddle::operators::OP_NAME##OpMaker,              \
+                    ::paddle::operators::ActivationOpInferVarType,      \
                     ::paddle::framework::DefaultGradOpDescMaker<true>); \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
......
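ActivationOpInferVarType is needed because a newly created output VarDesc defaults to LoDTensor; without propagating the input's variable type and data type, the compile-time ShareDim type check would fail whenever X is a SelectedRows. A standalone sketch of the guarantee (toy enums and VarDesc, not Paddle's API):

// Standalone sketch (toy types, not Paddle's) of what ActivationOpInferVarType
// guarantees: the output adopts the input's variable type and data type.
#include <cassert>

enum class VarType { kLoDTensor, kSelectedRows };
enum class DataType { kFP32, kFP64 };

struct VarDesc {
  VarType type = VarType::kLoDTensor;  // framework default for new vars
  DataType dtype = DataType::kFP32;
};

// Mirrors out.SetType(x.GetType()) and out.SetDataType(x.GetDataType()).
void InferVarType(const VarDesc& x, VarDesc* out) {
  out->type = x.type;
  out->dtype = x.dtype;
}

int main() {
  VarDesc x{VarType::kSelectedRows, DataType::kFP64};
  VarDesc out;  // would otherwise stay a default LoDTensor
  InferVarType(x, &out);
  assert(out.type == VarType::kSelectedRows && out.dtype == DataType::kFP64);
  return 0;
}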
...@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
                    "-rank(Input(X)) (%d).",
                    axis, num_dims);
-    ctx->SetOutputDim("Out", in_dims);
-    ctx->SetOutputDim("Indices", in_dims);
+    ctx->ShareDim("X", "Out");
+    ctx->ShareDim("X", "Indices");
     ctx->ShareLoD("X", "Out");
     ctx->ShareLoD("X", "Indices");
   }
......
...@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
                       "The 2nd dimension of Input(Y) should be less than or "
                       "equal to the 2nd dimension of Input(X).");
-    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
...@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     auto y_dim = ctx->GetInputDim("Y");
     PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                       "Rank of first input must >= rank of second input.");
-    ctx->SetOutputDim("Out", x_dim);
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
...@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
     auto& x = block->FindRecursiveOrCreateVar(x_name);
     auto& out = block->FindRecursiveOrCreateVar(out_name);
     out.SetType(x.GetType());
+    out.SetDataType(x.GetDataType());
   }
 };
 
...@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
+      ctx->ShareDim("X", /*->*/ x_grad_name);
+      ctx->ShareLoD("X", /*->*/ x_grad_name);
     }
     if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
+      ctx->ShareDim("Y", /*->*/ y_grad_name);
+      ctx->ShareLoD("Y", /*->*/ y_grad_name);
     }
   }
 
...@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
     auto x_grad_name = framework::GradVarName("X");
     if (ctx->HasOutput(x_grad_name)) {
-      auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-      ctx->SetOutputDim(x_grad_name, out_dims);
+      ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
+      ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
     }
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(y_grad_name)) {
       PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
-      auto y_dims = ctx->GetInputDim("Y");
-      ctx->SetOutputDim(y_grad_name, y_dims);
+
+      ctx->ShareDim("Y", /*->*/ y_grad_name);
+      ctx->ShareLoD("Y", /*->*/ y_grad_name);
     }
   }
 };
......
...@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
                    "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
...@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
     } else {
       PADDLE_THROW("Unknown mode %s", mode);
     }
-    ctx->SetOutputDim("Out", x_dim);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
......
...@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
                    "Input(X) of rnn_memory_helper op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output of rnn_memory_helper op should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
...@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
                         ctx->GetInputDim("PaddingData"));
     }
     if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-      ctx->ShareLoD("X", framework::GradVarName("X"));
+      ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
+      ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
     }
     if (ctx->HasOutput(framework::GradVarName("Filter"))) {
       ctx->SetOutputDim(framework::GradVarName("Filter"),
......
...@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
     for (int64_t i = 1; i < og_dims.size(); ++i) {
       PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
     }
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+
+    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 
  protected:
......
...@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SequenceReshapeGradOp should not be null.");
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
     ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 };
......
...@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
                    "Input(X) of SequenceSoftmaxOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceSoftmaxOp should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
......
...@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInput("X"));
     PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
-    context->SetOutputDim(framework::GradVarName("X"),
-                          context->GetInputDim("X"));
-    context->ShareLoD("X", framework::GradVarName("X"));
+
+    context->ShareDim("X", /*->*/ framework::GradVarName("X"));
+    context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 };
......
...@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
                       "The 2nd dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
...@@ -16,6 +16,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
 
 
 class ElementwiseMulOp(OpTest):
...@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
         }
 
 
+class TestElementWiseMulSelectedRows(OpTest):
+    def setUp(self):
+        self.rows = [0, 1, 2, 3, 4, 5, 6]
+        self.feature = 12
+        self.height = 100
+        self.input_shape = (len(self.rows), self.feature)
+
+    def prepare_input(self, scope, place):
+        self.input = {
+            "X": np.random.random(self.input_shape).astype("float32"),
+            "Y": np.random.random(self.input_shape).astype("float32")
+        }
+
+        def init_input(in_name):
+            x_selected_rows = scope.var(in_name).get_selected_rows()
+            x_selected_rows.set_height(self.height)
+            x_selected_rows.set_rows(self.rows)
+            x_array = self.input[in_name]
+            x_tensor = x_selected_rows.get_tensor()
+            x_tensor.set(x_array, place)
+
+        init_input("X")
+        init_input("Y")
+
+    def create_out_selected_row(self, scope):
+        return scope.var('Out').get_selected_rows()
+
+    def check_result(self, out_selected_rows):
+        assert out_selected_rows.height() == self.height
+        assert out_selected_rows.rows() == self.rows
+        out_tensor = np.array(out_selected_rows.get_tensor())
+        assert out_tensor.shape == self.input_shape
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        self.prepare_input(scope, place)
+
+        out_selected_rows = self.create_out_selected_row(scope)
+        out_selected_rows.set_height(0)
+        out_selected_rows.set_rows([])
+
+        elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out')
+        elementwise_mul.run(scope, place)
+
+        self.check_result(out_selected_rows)
+
+    def test_elewisemul_with_selected_rows_input(self):
+        places = [core.CPUPlace()]
+        for place in places:
+            self.check_with_place(place)
+
+
 if __name__ == '__main__':
     unittest.main()