From 2c05465d2f0f1134610364508fa73281fd44f1ad Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 27 Sep 2017 17:10:07 -0700 Subject: [PATCH] Fix unit-tests --- paddle/framework/operator_test.cc | 7 +++++-- paddle/operators/gather_op.cc | 10 ++++++++++ paddle/operators/gaussian_random_op.cc | 7 +++++++ paddle/operators/lookup_table_op.cc | 10 ++++++++++ paddle/operators/multiplex_op.cc | 10 ++++++++++ paddle/operators/scatter_op.cc | 10 ++++++++++ paddle/operators/softmax_with_cross_entropy_op.cc | 11 +++++++++++ paddle/operators/uniform_random_op.cc | 7 +++++++ 8 files changed, 70 insertions(+), 2 deletions(-) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 8b4bb01a7bb..7f0ec90adef 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -116,10 +116,13 @@ class OpWithKernelTest : public OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase* ctx) const override {} + DataType IndicateDataType(const ExecutionContext& ctx) const override { + return DataType::FP32; + } }; template -class CPUKernelTest : public OpKernel { +class CPUKernelTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { std::cout << "this is cpu kernel" << std::endl; @@ -146,7 +149,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker } }; -class CPUKernalMultiInputsTest : public OpKernel { +class CPUKernalMultiInputsTest : public OpKernel { public: void Compute(const ExecutionContext& ctx) const { auto xs = ctx.op().Inputs("xs"); diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 0e3cd174ade..da22bd0c52c 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel { output_dims[0] = batch_size; ctx->SetOutputDim("Out", output_dims); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class GatherGradOp : public framework::OperatorWithKernel { @@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContextBase* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class GatherOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index fc340c181cc..5cd2c7d2c06 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel { "dims can be one int or array. dims must be set."); ctx->SetOutputDim("Out", framework::make_ddim(temp)); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("data_type")); + } }; class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { @@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator. "Random seed of generator." "0 means use system wide seed") .SetDefault(0); + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); } }; diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 9b1314bfbad..929008fbcbe 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("W")->type()); + } }; class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { @@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("W")->type()); + } }; } // namespace operators diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 9896d269ccc..a069127a19a 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", in_dim); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.MultiInput("X")[0]->type()); + } }; class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { @@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel { } ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.MultiInput("X")[0]->type()); + } }; } // namespace operators diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 3fc4a39ebc5..619acfc8b62 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", ref_dims); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class ScatterGradOp : public framework::OperatorWithKernel { @@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("Updates")); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("X")->type()); + } }; class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index e2299b25445..de7c532421c 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/softmax_with_cross_entropy_op.h" +#include namespace paddle { namespace operators { @@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ctx->ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Loss"); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("Logits")->type()); + } }; class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { @@ -149,6 +155,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->GetInputDim("Softmax")); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return framework::ToDataType(ctx.Input("Logits")->type()); + } }; } // namespace operators diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 878d71802ab..97b1d0bed45 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", framework::make_ddim(temp)); } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("data_type")); + } }; class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { @@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator. "Random seed of uniform random. " "0 means generate a seed by system") .SetDefault(0); + AddAttr("data_type", "output tensor data type") + .SetDefault(framework::DataType::FP32); } }; } // namespace operators -- GitLab