Commit 9445502f authored by nhzlx

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into trt_dy_lib

test=develop
@@ -25,5 +25,6 @@ third_party/
 bazel-*
 third_party/
+build_*
 
 # clion workspace.
 cmake-build-*
@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   const std::vector<std::string> &Outputs(
       const std::string &name) const override;
 
+  void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
+                size_t j = 0) override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    const std::string &input_n = Inputs(in)[i];
+    const std::string &output_n = Outputs(out)[j];
+
+    PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
+                   in, i);
+    PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
+
+    auto *in_var = block_.FindVarRecursive(input_n);
+    auto *out_var = block_.FindVarRecursive(output_n);
+
+    PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
+                   "The type of %s and %s is not the same.", input_n, output_n);
+
+    SetDim(output_n, GetDim(input_n));
+  }
+
   void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                 size_t j = 0) const override {
     PADDLE_ENFORCE_LT(i, Inputs(in).size());
......
@@ -542,6 +542,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return op_.Outputs(name);
   }
 
+  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
+                size_t j = 0) override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    const std::string& input_n = Inputs(in)[i];
+    const std::string& output_n = Outputs(out)[j];
+
+    Variable* in_var = scope_.FindVar(input_n);
+    Variable* out_var = scope_.FindVar(output_n);
+    PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
+                   "The type of %s and %s is not the same.", input_n,
+                   output_n);
+
+    if (in_var->IsType<framework::SelectedRows>()) {
+      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
+      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
+      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
+      out_sele_rows->set_rows(in_sele_rows.rows());
+      out_sele_rows->set_height(in_sele_rows.height());
+    } else if (in_var->IsType<framework::LoDTensor>()) {
+      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
+      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
+      out_lod_tensor->Resize(in_lod_tensor.dims());
+    } else {
+      PADDLE_THROW(
+          "Currently, the input type of ShareDim only can be LoDTensor "
+          "or SelectedRows.");
+    }
+  }
+
   void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                 size_t j = 0) const override {
     const std::vector<std::string>& inputs = Inputs(in);
......
@@ -56,6 +56,9 @@ class InferShapeContext {
   virtual const std::vector<std::string> &Outputs(
       const std::string &name) const = 0;
 
+  virtual void ShareDim(const std::string &in, const std::string &out,
+                        size_t i = 0, size_t j = 0) = 0;
+
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;
......
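Note: the new ShareDim interface declared above (implemented for compile time in op_desc.cc and for runtime in operator.cc earlier in this diff) replaces the SetOutputDim(name, GetInputDim(...)) pattern and, at runtime, also covers SelectedRows. A minimal Python sketch of the observable compile-time effect, assuming a standard fluid build of this era where relu's InferShape now calls ShareDim (see the activation_op.cc hunk below):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[3, 4], dtype='float32')
    out = fluid.layers.relu(x)
    # Out inherits X's dims via ShareDim when the op is appended;
    # the leading batch dimension stays -1.
    print(x.shape, out.shape)  # expected: (-1, 3, 4) (-1, 3, 4)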
@@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
@@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel {
   }
 };
 
+class ActivationOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto x_name = op_desc.Input("X")[0];
+    auto out_name = op_desc.Output("Out")[0];
+    auto& x = block->FindRecursiveOrCreateVar(x_name);
+    auto& out = block->FindRecursiveOrCreateVar(out_name);
+    out.SetType(x.GetType());
+    out.SetDataType(x.GetDataType());
+  }
+};
+
 class ActivationOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
+    ctx->ShareDim("Out", framework::GradVarName("X"));
+    ctx->ShareLoD("Out", framework::GradVarName("X"));
   }
 
  protected:
@@ -525,12 +539,14 @@ namespace ops = paddle::operators;
 
 #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                     ::paddle::operators::OP_NAME##OpMaker,          \
+                    ::paddle::operators::ActivationOpInferVarType,  \
                     ::paddle::operators::OP_NAME##GradMaker);       \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                     ::paddle::operators::OP_NAME##OpMaker,          \
+                    ::paddle::operators::ActivationOpInferVarType,  \
                     ::paddle::framework::DefaultGradOpDescMaker<true>); \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
......
@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
                    "-rank(Input(X)) (%d).",
                    axis, num_dims);
 
-    ctx->SetOutputDim("Out", in_dims);
-    ctx->SetOutputDim("Indices", in_dims);
+    ctx->ShareDim("X", "Out");
+    ctx->ShareDim("X", "Indices");
     ctx->ShareLoD("X", "Out");
     ctx->ShareLoD("X", "Indices");
   }
......
@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
                       "The 2nd dimension of Input(Y) should be less than or "
                       "equal to the 2nd dimension of Input(X).");
-    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     auto y_dim = ctx->GetInputDim("Y");
     PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                       "Rank of first input must >= rank of second input.");
-    ctx->SetOutputDim("Out", x_dim);
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
     auto& x = block->FindRecursiveOrCreateVar(x_name);
     auto& out = block->FindRecursiveOrCreateVar(out_name);
     out.SetType(x.GetType());
+    out.SetDataType(x.GetDataType());
   }
 };
@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
+      ctx->ShareDim("X", /*->*/ x_grad_name);
+      ctx->ShareLoD("X", /*->*/ x_grad_name);
     }
     if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
+      ctx->ShareDim("Y", /*->*/ y_grad_name);
+      ctx->ShareLoD("Y", /*->*/ y_grad_name);
     }
   }
@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
     auto x_grad_name = framework::GradVarName("X");
     if (ctx->HasOutput(x_grad_name)) {
-      auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-      ctx->SetOutputDim(x_grad_name, out_dims);
+      ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
+      ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
     }
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(y_grad_name)) {
       PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
-      auto y_dims = ctx->GetInputDim("Y");
-      ctx->SetOutputDim(y_grad_name, y_dims);
+
+      ctx->ShareDim("Y", /*->*/ y_grad_name);
+      ctx->ShareLoD("Y", /*->*/ y_grad_name);
     }
   }
 };
......
@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
                    "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
@@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
              << " is set to LoDTensor";
       block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
+    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
   }
 };
......
@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
     } else {
       PADDLE_THROW("Unkown mode %s", mode);
     }
-    ctx->SetOutputDim("Out", x_dim);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
......
@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
                    "Input(X) of rnn_memory_helper op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output of rnn_memory_helper op should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
@@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
                         ctx->GetInputDim("PaddingData"));
     }
     if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-      ctx->ShareLoD("X", framework::GradVarName("X"));
+      ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
+      ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
     }
     if (ctx->HasOutput(framework::GradVarName("Filter"))) {
       ctx->SetOutputDim(framework::GradVarName("Filter"),
......
@@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
     for (int64_t i = 1; i < og_dims.size(); ++i) {
       PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
     }
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+
+    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 
  protected:
......
@@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SequenceReshapeGradOp should not be null.");
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
     ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 };
......
@@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
                    "Input(X) of SequenceSoftmaxOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceSoftmaxOp should not be null.");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
......
@@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInput("X"));
     PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
-    context->SetOutputDim(framework::GradVarName("X"),
-                          context->GetInputDim("X"));
-    context->ShareLoD("X", framework::GradVarName("X"));
+
+    context->ShareDim("X", /*->*/ framework::GradVarName("X"));
+    context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
   }
 };
......
@@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
                       "The 2nd dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareDim("X", /*->*/ "Out");
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
......
@@ -620,7 +620,23 @@ All parameter, weight, gradient are variables in Paddle.
   // -- python binds for parallel executor.
   py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
-  py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy");
+  py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
+    ExecutionStrategy allows the user to more precisely control how to run
+    the program in ParallelExecutor by setting its properties.
+
+    The available properties include:
+        use_cuda (bool): Whether to use CUDA or not. Default True.
+        num_threads (int): The number of threads used to run the
+            operators in ParallelExecutor. If it is not set, it will be
+            set in ParallelExecutor according to the device count.
+            Default 0.
+        allow_op_delay (bool): Whether to delay the run of communication
+            operators. Default False.
+        num_iteration_per_drop_scope (int): How many iterations to run
+            between two droppings of local scopes. Default 100.
+  )DOC");
+
   exec_strategy.def(py::init())
       .def_property(
           "num_threads",
@@ -658,7 +674,25 @@ All parameter, weight, gradient are variables in Paddle.
               : ExecutionStrategy::kDefault;
         });
-  py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
+  py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy", R"DOC(
+    BuildStrategy allows the user to more precisely control how to
+    build the SSA Graph in ParallelExecutor by setting its properties.
+
+    The available properties include:
+        reduce_strategy (str): There are two reduce strategies, 'AllReduce'
+            and 'Reduce'. If you want all parameters to be optimized on all
+            devices, choose 'AllReduce'; if you choose 'Reduce', all
+            parameters are evenly allocated to different devices for
+            optimization, and the optimized parameters are then broadcast
+            to the other devices. Default 'AllReduce'.
+        gradient_scale_strategy (str): There are two ways of defining
+            loss@grad, 'CoeffNumDevice' and 'Customized'. By default,
+            ParallelExecutor sets loss@grad according to the number of
+            devices. If you want to customize loss@grad, choose
+            'Customized'. Default 'CoeffNumDevice'.
+        debug_graphviz_path (str): The path to write the SSA Graph to, in
+            graphviz format. Useful for debugging. Default "".
+  )DOC");
+
   py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
       .value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
......
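Note: a matching usage sketch for BuildStrategy, under the same assumptions; the GradientScaleStrategy enum path is inferred from the docstring and the ReduceStrategy binding above, so treat it as an assumption:

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
    build_strategy.gradient_scale_strategy = \
        fluid.BuildStrategy.GradientScaleStrategy.CoeffNumDevice  # assumed enum, per docstring
    build_strategy.debug_graphviz_path = "./ssa_graph"  # dump the SSA Graph for debugging

    pe = fluid.ParallelExecutor(use_cuda=True,
                                loss_name=loss.name,
                                build_strategy=build_strategy)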
@@ -74,7 +74,7 @@ def create_lod_tensor(data, recursive_seq_lens, place):
         assert [
             new_recursive_seq_lens
         ] == recursive_seq_lens, "data and recursive_seq_lens do not match"
-        flattened_data = np.concatenate(data, axis=0).astype("int64")
+        flattened_data = np.concatenate(data, axis=0)
         flattened_data = flattened_data.reshape([len(flattened_data), 1])
         return create_lod_tensor(flattened_data, recursive_seq_lens, place)
     elif isinstance(data, np.ndarray):
......
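Note: dropping the astype("int64") cast means list input to create_lod_tensor keeps whatever dtype np.concatenate infers instead of always becoming int64. A small sketch of the visible effect (hypothetical values):

    import numpy as np
    import paddle.fluid as fluid

    # Two sequences of lengths 2 and 1; float data is no longer coerced to int64.
    t = fluid.create_lod_tensor([[1.0, 2.0], [3.0]], [[2, 1]], fluid.CPUPlace())
    arr = np.array(t)            # LoDTensor -> numpy, as the new test below also does
    print(arr.dtype, arr.shape)  # expected: float64 (3, 1)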
@@ -16,6 +16,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
 
 
 class ElementwiseMulOp(OpTest):
@@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
         }
 
 
+class TestElementWiseMulSelectedRows(OpTest):
+    def setUp(self):
+        self.rows = [0, 1, 2, 3, 4, 5, 6]
+        self.feature = 12
+        self.height = 100
+        self.input_shape = (len(self.rows), self.feature)
+
+    def prepare_input(self, scope, place):
+        self.input = {
+            "X": np.random.random(self.input_shape).astype("float32"),
+            "Y": np.random.random(self.input_shape).astype("float32")
+        }
+
+        def init_input(in_name):
+            x_selected_rows = scope.var(in_name).get_selected_rows()
+            x_selected_rows.set_height(self.height)
+            x_selected_rows.set_rows(self.rows)
+            x_array = self.input[in_name]
+            x_tensor = x_selected_rows.get_tensor()
+            x_tensor.set(x_array, place)
+
+        init_input("X")
+        init_input("Y")
+
+    def create_out_selected_row(self, scope):
+        return scope.var('Out').get_selected_rows()
+
+    def check_result(self, out_selected_rows):
+        assert out_selected_rows.height() == self.height
+        assert out_selected_rows.rows() == self.rows
+        out_tensor = np.array(out_selected_rows.get_tensor())
+        assert out_tensor.shape == self.input_shape
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        self.prepare_input(scope, place)
+
+        out_selected_rows = self.create_out_selected_row(scope)
+        out_selected_rows.set_height(0)
+        out_selected_rows.set_rows([])
+
+        elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out')
+        elementwise_mul.run(scope, place)
+
+        self.check_result(out_selected_rows)
+
+    def test_elewisemul_with_selected_rows_input(self):
+        places = [core.CPUPlace()]
+        for place in places:
+            self.check_with_place(place)
+
+
 if __name__ == '__main__':
     unittest.main()