提交 2c05465d 编写于 作者: Y Yu Yang

Fix unit-tests

上级 3a5693e0
...@@ -116,10 +116,13 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -116,10 +116,13 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {} void InferShape(framework::InferShapeContextBase* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
}; };
template <typename T1, typename T2> template <typename T1, typename T2>
class CPUKernelTest : public OpKernel { class CPUKernelTest : public OpKernel<float> {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl; std::cout << "this is cpu kernel" << std::endl;
...@@ -146,7 +149,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -146,7 +149,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
} }
}; };
class CPUKernalMultiInputsTest : public OpKernel { class CPUKernalMultiInputsTest : public OpKernel<float> {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op().Inputs("xs"); auto xs = ctx.op().Inputs("xs");
......
...@@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel {
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims); ctx->SetOutputDim("Out", output_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class GatherGradOp : public framework::OperatorWithKernel { class GatherGradOp : public framework::OperatorWithKernel {
...@@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContextBase* ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
ctx->SetOutputDim("Out", framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
}
}; };
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator. ...@@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator.
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed") "0 means use system wide seed")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32);
} }
}; };
......
...@@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out"); ctx->ShareLoD("Ids", /*->*/ "Out");
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
}; };
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims); ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", in_dim); ctx->SetOutputDim("Out", in_dim);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
}; };
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); ctx->SetOutputsDim(framework::GradVarName("X"), d_ins);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", ref_dims); ctx->SetOutputDim("Out", ref_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class ScatterGradOp : public framework::OperatorWithKernel { class ScatterGradOp : public framework::OperatorWithKernel {
...@@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("Updates")); ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/softmax_with_cross_entropy_op.h" #include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss"); ctx->ShareLoD("Logits", /*->*/ "Loss");
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
}
}; };
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...@@ -149,6 +155,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -149,6 +155,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax")); ctx->GetInputDim("Softmax"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
}
}; };
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator. ...@@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator.
"Random seed of uniform random. " "Random seed of uniform random. "
"0 means generate a seed by system") "0 means generate a seed by system")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("data_type", "output tensor data type")
.SetDefault(framework::DataType::FP32);
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册