未验证 提交 e1761709 编写于 作者: C chengduo 提交者: GitHub

Set the right shape of selected_rows (#13723)

* set the right shape of selected_rows
test=develop

* enhance check

* fix activation_op

* remove cast

* use ShareDimInfo replace SetDim and ShareLod

* use ShareDimAndLod
test=develop

* follow comment

test=develop

* check whether the input has lod
test=develop

* Split ShareDimAndLod

test=develop

* checkout clip.py
test=develop
上级 2a36f0a3
...@@ -25,5 +25,6 @@ third_party/ ...@@ -25,5 +25,6 @@ third_party/
bazel-* bazel-*
third_party/ third_party/
build_*
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
...@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs( const std::vector<std::string> &Outputs(
const std::string &name) const override; const std::string &name) const override;
// Copies the compile-time dimensions of the i-th variable bound to input slot
// `in` onto the j-th variable bound to output slot `out`.
//
// Enforces that both indices are in range, that neither variable name is the
// empty-variable placeholder, that both variables can be looked up through
// block_, and that the two variables have the same type.
void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
              size_t j = 0) override {
  PADDLE_ENFORCE_LT(i, Inputs(in).size());
  PADDLE_ENFORCE_LT(j, Outputs(out).size());
  const std::string &input_n = Inputs(in)[i];
  const std::string &output_n = Outputs(out)[j];
  PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
                 in, i);
  PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
                 "The %s[%d] is @EMPTY@", out, j);

  auto *in_var = block_.FindVarRecursive(input_n);
  auto *out_var = block_.FindVarRecursive(output_n);
  // Both pointers are dereferenced right below; raise a readable error
  // instead of crashing if either lookup yields a null pointer.
  PADDLE_ENFORCE(in_var != nullptr, "The variable %s is not found.", input_n);
  PADDLE_ENFORCE(out_var != nullptr, "The variable %s is not found.",
                 output_n);
  PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
                 "The type of %s and %s is not the same.", input_n, output_n);

  SetDim(output_n, GetDim(input_n));
}
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(i, Inputs(in).size());
......
...@@ -542,6 +542,36 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -542,6 +542,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name); return op_.Outputs(name);
} }
// Makes the j-th variable of output slot `out` take the shape of the i-th
// variable of input slot `in`.
//
// For a SelectedRows input, the output's value tensor is resized and the rows
// and height are copied as well. For a LoDTensor input only the dims are
// copied. Any other variable type is rejected with an error.
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
              size_t j = 0) override {
  PADDLE_ENFORCE_LT(i, Inputs(in).size());
  PADDLE_ENFORCE_LT(j, Outputs(out).size());
  const std::string& input_n = Inputs(in)[i];
  const std::string& output_n = Outputs(out)[j];

  Variable* in_var = scope_.FindVar(input_n);
  Variable* out_var = scope_.FindVar(output_n);
  // Both pointers are dereferenced right below; raise a readable error
  // instead of crashing if either lookup yields a null pointer.
  PADDLE_ENFORCE(in_var != nullptr, "The variable %s is not found.", input_n);
  PADDLE_ENFORCE(out_var != nullptr, "The variable %s is not found.",
                 output_n);
  // The message names the two variables whose types are being compared.
  PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
                 "The type of %s and %s is not the same.", input_n, output_n);

  if (in_var->IsType<framework::SelectedRows>()) {
    auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
    auto* out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
    out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
    out_sele_rows->set_rows(in_sele_rows.rows());
    out_sele_rows->set_height(in_sele_rows.height());
  } else if (in_var->IsType<framework::LoDTensor>()) {
    auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
    auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
    out_lod_tensor->Resize(in_lod_tensor.dims());
  } else {
    PADDLE_THROW(
        "Currently, the input type of ShareDim only can be LoDTensor "
        "or SelectedRows.");
  }
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
const std::vector<std::string>& inputs = Inputs(in); const std::vector<std::string>& inputs = Inputs(in);
......
...@@ -56,6 +56,9 @@ class InferShapeContext { ...@@ -56,6 +56,9 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs( virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0; const std::string &name) const = 0;
// Gives the j-th variable of output slot `out` the dimensions of the i-th
// variable of input slot `in`. Pure virtual: each concrete context supplies
// its own implementation. Declared non-const, unlike ShareLoD.
virtual void ShareDim(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) = 0;
virtual void ShareLoD(const std::string &in, const std::string &out, virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0; size_t i = 0, size_t j = 0) const = 0;
......
...@@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel {
} }
}; };
// Var-type inference for activation ops: the first "Out" variable takes the
// variable type and the data type of the first "X" variable.
class ActivationOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    const auto input_name = op_desc.Input("X")[0];
    const auto output_name = op_desc.Output("Out")[0];

    // Look up (or create) the input variable first, then the output variable.
    auto& src_var = block->FindRecursiveOrCreateVar(input_name);
    auto& dst_var = block->FindRecursiveOrCreateVar(output_name);

    // The variable type is propagated before the data type.
    const auto var_type = src_var.GetType();
    dst_var.SetType(var_type);
    const auto data_type = src_var.GetDataType();
    dst_var.SetDataType(data_type);
  }
};
class ActivationOpGrad : public framework::OperatorWithKernel { class ActivationOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); ctx->ShareDim("Out", framework::GradVarName("X"));
ctx->ShareLoD("Out", framework::GradVarName("X"));
} }
protected: protected:
...@@ -525,12 +539,14 @@ namespace ops = paddle::operators; ...@@ -525,12 +539,14 @@ namespace ops = paddle::operators;
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \ ::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \ ::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
......
...@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
"-rank(Input(X)) (%d).", "-rank(Input(X)) (%d).",
axis, num_dims); axis, num_dims);
ctx->SetOutputDim("Out", in_dims); ctx->ShareDim("X", "Out");
ctx->SetOutputDim("Indices", in_dims); ctx->ShareDim("X", "Indices");
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices"); ctx->ShareLoD("X", "Indices");
} }
......
...@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel { ...@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or " "The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X)."); "equal to the 2nd dimension of Input(X).");
ctx->SetOutputDim("Out", x_dims); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input."); "Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference { ...@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
auto& x = block->FindRecursiveOrCreateVar(x_name); auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name); auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType()); out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
} }
}; };
...@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
} }
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims); ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
} }
} }
...@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
ctx->SetOutputDim(x_grad_name, out_dims); ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
} }
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
auto y_dims = ctx->GetInputDim("Y");
ctx->SetOutputDim(y_grad_name, y_dims); ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
} }
} }
}; };
......
...@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { ...@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
"Input(X) of FakeDequantizeMaxAbsOp should not be null."); "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null."); "Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { ...@@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
<< " is set to LoDTensor"; << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
} }
}; };
......
...@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
} else { } else {
PADDLE_THROW("Unkown mode %s", mode); PADDLE_THROW("Unkown mode %s", mode);
} }
ctx->SetOutputDim("Out", x_dim); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
...@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { ...@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
"Input(X) of rnn_memory_helper op should not be null."); "Input(X) of rnn_memory_helper op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of rnn_memory_helper op should not be null."); "Output of rnn_memory_helper op should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ...@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("PaddingData")); ctx->GetInputDim("PaddingData"));
} }
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
if (ctx->HasOutput(framework::GradVarName("Filter"))) { if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), ctx->SetOutputDim(framework::GradVarName("Filter"),
......
...@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { ...@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
for (int64_t i = 1; i < og_dims.size(); ++i) { for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
} }
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
protected: protected:
......
...@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel { ...@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceReshapeGradOp should not be null."); "Input(X) of SequenceReshapeGradOp should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
}; };
......
...@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
"Input(X) of SequenceSoftmaxOp should not be null."); "Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSoftmaxOp should not be null."); "Output(Out) of SequenceSoftmaxOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
...@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase { ...@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X")); PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X"))); PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X")); context->ShareDim("X", /*->*/ framework::GradVarName("X"));
context->ShareLoD("X", framework::GradVarName("X")); context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
}; };
......
...@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { ...@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
"The 2nd dimension of Input(X) and Input(Label) should " "The 2nd dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
ctx->SetOutputDim("Out", x_dims); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
......
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class ElementwiseMulOp(OpTest): class ElementwiseMulOp(OpTest):
...@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
} }
class TestElementWiseMulSelectedRows(OpTest):
    """Runs elementwise_mul on SelectedRows inputs and checks the output metadata.

    Only the height, the rows and the tensor shape of ``Out`` are verified;
    the numerical values of the product are not compared.
    """

    def setUp(self):
        self.rows = [0, 1, 2, 3, 4, 5, 6]
        self.feature = 12
        self.height = 100
        self.input_shape = (len(self.rows), self.feature)

    def prepare_input(self, scope, place):
        """Create SelectedRows variables "X" and "Y" in `scope` with random data."""
        self.input = {
            "X": np.random.random(self.input_shape).astype("float32"),
            "Y": np.random.random(self.input_shape).astype("float32")
        }

        def init_input(in_name):
            # Fill one SelectedRows variable: height, rows, then the tensor.
            x_selected_rows = scope.var(in_name).get_selected_rows()
            x_selected_rows.set_height(self.height)
            x_selected_rows.set_rows(self.rows)
            x_array = self.input[in_name]
            x_tensor = x_selected_rows.get_tensor()
            x_tensor.set(x_array, place)

        init_input("X")
        init_input("Y")

    def create_out_selected_row(self, scope):
        """Return the SelectedRows held by the 'Out' variable of `scope`."""
        return scope.var('Out').get_selected_rows()

    def check_result(self, out_selected_rows):
        """Check that the output carries the inputs' height, rows and shape."""
        # unittest assertions are used instead of bare `assert` so the checks
        # still run under `python -O` and report both values on failure.
        self.assertEqual(out_selected_rows.height(), self.height)
        self.assertEqual(out_selected_rows.rows(), self.rows)
        out_tensor = np.array(out_selected_rows.get_tensor())
        self.assertEqual(out_tensor.shape, self.input_shape)

    def check_with_place(self, place):
        """Run elementwise_mul once on `place` and verify the output metadata."""
        scope = core.Scope()
        self.prepare_input(scope, place)

        # Start from an output with height 0 and no rows, so a passing check
        # shows that running the op filled them in.
        out_selected_rows = self.create_out_selected_row(scope)
        out_selected_rows.set_height(0)
        out_selected_rows.set_rows([])

        elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out')
        elementwise_mul.run(scope, place)
        self.check_result(out_selected_rows)

    def test_elewisemul_with_selected_rows_input(self):
        places = [core.CPUPlace()]
        for place in places:
            self.check_with_place(place)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册