提交 c5e28dd1 编写于 作者: Z zchen0211

scatter check in

上级 a8890110
......@@ -47,6 +47,7 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
op_library(scatter_op SRCS scatter_op.cc scatter_op.cu)
cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/scatter_op.h"
#include "paddle/framework/ddim.h"
namespace paddle {
namespace operators {
class ScatterOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
framework::DDim output_dims(ctx.Input<Tensor>("Ref")->dims());
ctx.Output<Tensor>("Out")->Resize(output_dims);
}
};
class ScatterGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto Updates_grad = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto Updates = ctx.Input<Tensor>("Updates");
auto Ref_grad = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto Ref = ctx.Input<Tensor>("Ref");
Ref_grad->Resize(Ref->dims());
Updates_grad->Resize(Updates->dims());
}
};
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScatterOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ref", "The source input of scatter op");
AddInput("Index",
"The index input of scatter op where Ref will be updated");
AddInput("Updates", "The updated value of updates op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Scatter Operator by selecting from the first axis,
Out = Ref
Out[Index] = Ref[Index] + Updates
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
ops::ScatterGradOp);
REGISTER_OP_CPU_KERNEL(scatter,
ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
scatter_grad,
ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/scatter_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(scatter,
ops::ScatterOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "gather.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class ScatterOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *Ref = ctx.Input<Tensor>("Ref");
auto *Index = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
// In place output: Out = Ref, Out[Index] += Updates
Out->ShareDataWith<T>(*Ref);
// Apply ScatterUpdate: Out[index] += Updates[:]
ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
}
};
template <typename Place, typename T>
class ScatterGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index");
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO
dRef->ShareDataWith<T>(*dO);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index]
Gather<T>(ctx.GetPlace(), dO, Index, dUpdates);
}
};
} // namespace operators
} // namespace paddle
......@@ -4,6 +4,7 @@ cc_library(paddle_pybind SHARED
DEPS pybind python backward
sgd_op
gather_op
scatter_op
add_op
mul_op
rowwise_add_op
......
......@@ -47,6 +47,7 @@ USE_OP(scale);
USE_OP_ITSELF(identity);
USE_OP(minus);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
namespace paddle {
namespace framework {
......
......@@ -14,6 +14,7 @@ py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_scatter_op SRCS test_scatter_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......
......@@ -21,12 +21,9 @@ class TestGatherOp(unittest.TestCase):
class TestGatherGradOp(GradientChecker):
def test_gather_grad(self):
print 'creating op'
op = create_op("gather")
print 'creating op done'
xnp = numpy.random.random((10, 20)).astype("float32")
inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
print 'correct before check gradient'
self.check_grad(op, inputs, set("X"), "Out")
......
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestScatterOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "scatter"
ref_np = numpy.ones((3, 3)).astype("float32")
index_np = numpy.array([1, 2]).astype("int32")
updates_np = numpy.random.random((2, 3)).astype("float32")
output_np = numpy.copy(ref_np)
output_np[index_np] += updates_np
self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
class TestScatterGradOp(GradientChecker):
def test_scatter_grad(self):
op = create_op("scatter")
# test data setup
ref_np = numpy.ones((3, 10)).astype("float32")
index_np = numpy.array([1, 2]).astype("int32")
updates_np = numpy.random.random((2, 10)).astype("float32")
output_np = numpy.copy(ref_np)
output_np[index_np] += updates_np
inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
# check gradient
self.check_grad(op, inputs, set(["Updates", "Ref"]), "Out")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册