Commit bbd6e09c authored by yangyaming

Using LoDTensor for output.

Parent: 308ce9ac
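The change below is mechanical across every hunk: inputs are still fetched as Tensor, while every output variable is now fetched as framework::LoDTensor before being resized or having its memory allocated. What follows is a minimal sketch of that pattern, not the real operator; the include path and the kernel name ExampleLossKernel are assumptions for illustration, only the Input/Output/Resize/mutable_data calls are taken from the diff itself.

// Sketch of the pattern this commit applies: inputs keep the Tensor
// interface, outputs are retrieved as framework::LoDTensor.
// The include path and class name are illustrative assumptions.
#include "paddle/framework/op_registry.h"  // assumed header for this version

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class ExampleLossKernel : public framework::OpKernel {  // hypothetical kernel
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Inputs are still read as plain Tensor.
    auto* x = context.Input<Tensor>("X");
    // Outputs are retrieved as LoDTensor so LoD information can be carried.
    auto* out = context.Output<framework::LoDTensor>("Out");
    // In the real operator the Resize happens in the op's InferShape; it is
    // shown here only to keep the sketch self-contained.
    out->Resize(x->dims());
    out->mutable_data<T>(context.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle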
@@ -34,8 +34,8 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1.");
-    context.Output<Tensor>("IntermediateVal")->Resize(x->dims());
-    context.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
+    context.Output<framework::LoDTensor>("IntermediateVal")->Resize(x->dims());
+    context.Output<framework::LoDTensor>("Out")->Resize({x->dims()[0], 1});
   }
 };
@@ -80,7 +80,8 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
     auto* y = context.Input<Tensor>("Y");
     auto* intermediate_val = context.Input<Tensor>("IntermediateVal");
     auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* x_grad =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized.");
     PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized.");
...
@@ -52,8 +52,8 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
     auto* in1 = context.Input<Tensor>("Y");
-    auto* out0 = context.Output<Tensor>("IntermediateVal");
-    auto* out1 = context.Output<Tensor>("Out");
+    auto* out0 = context.Output<framework::LoDTensor>("IntermediateVal");
+    auto* out1 = context.Output<framework::LoDTensor>("Out");
     out0->mutable_data<T>(context.GetPlace());
     out1->mutable_data<T>(context.GetPlace());
@@ -77,9 +77,11 @@ class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("Y");
-    auto* in1 = context.Input<Tensor>("IntermediateVal");
-    auto* in2 = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* in1 = context.Input<framework::LoDTensor>("IntermediateVal");
+    auto* in2 =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* out0 =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     if (out0) {
       const T* y_ptr = in0->data<T>();
...