提交 bbd6e09c 编写于 作者: Y yangyaming

Using LoDTensor for output.

上级 308ce9ac
......@@ -34,8 +34,8 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2.");
PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1.");
context.Output<Tensor>("IntermediateVal")->Resize(x->dims());
context.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
context.Output<framework::LoDTensor>("IntermediateVal")->Resize(x->dims());
context.Output<framework::LoDTensor>("Out")->Resize({x->dims()[0], 1});
}
};
......@@ -80,7 +80,8 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
auto* y = context.Input<Tensor>("Y");
auto* intermediate_val = context.Input<Tensor>("IntermediateVal");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized.");
......
......@@ -52,8 +52,8 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<Tensor>("IntermediateVal");
auto* out1 = context.Output<Tensor>("Out");
auto* out0 = context.Output<framework::LoDTensor>("IntermediateVal");
auto* out1 = context.Output<framework::LoDTensor>("Out");
out0->mutable_data<T>(context.GetPlace());
out1->mutable_data<T>(context.GetPlace());
......@@ -77,9 +77,11 @@ class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y");
auto* in1 = context.Input<Tensor>("IntermediateVal");
auto* in2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* in1 = context.Input<framework::LoDTensor>("IntermediateVal");
auto* in2 =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* out0 =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (out0) {
const T* y_ptr = in0->data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册