未验证 提交 71400711 编写于 作者: D dzhwinter 提交者: GitHub

"exported scatter to python" (#9038)

* "exported scatter to python"

* Revert ""exported scatter to python""

This reverts commit 38745a62.

* "polish scatter and export to python"
上级 cf2addd2
...@@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ref"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(Ref) of ScatterOp should not be null."); "Input(X) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"), PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Index) of ScatterOp should not be null."); "Input(Ids) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Updates"), PADDLE_ENFORCE(ctx->HasInput("Updates"),
"Input(Updates) of ScatterOp should not be null."); "Input(Updates) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ScatterOp should not be null."); "Output(Out) of ScatterOp should not be null.");
auto updates_dims = ctx->GetInputDim("Updates"); auto updates_dims = ctx->GetInputDim("Updates");
auto ref_dims = ctx->GetInputDim("Ref"); auto ref_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1, PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
"Update Index should be 1-D."); "Update Ids should be 1-D.");
PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(), PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
"Reference and Updates should have the same shape size"); "Xerence and Updates should have the same shape size");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0], PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
ctx->GetInputDim("Index")[0], ctx->GetInputDim("Ids")[0],
"Updates and Index should have same batch-size."); "Updates and Ids should have same batch-size.");
framework::DDim data_dim(updates_dims); framework::DDim data_dim(updates_dims);
for (int i = 1; i < data_dim.size(); ++i) { for (int i = 1; i < data_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]); PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]);
...@@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates")); ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker) ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ref", "The source input of scatter op"); AddInput("X", "The source input of scatter op");
AddInput("Index", AddInput("Ids", "The index input of scatter op where X will be updated");
"The index input of scatter op where Ref will be updated");
AddInput("Updates", "The updated value of updates op"); AddInput("Updates", "The updated value of updates op");
AddOutput("Out", "The output of add op"); AddOutput("Out", "The output of add op");
AddComment(R"DOC( AddComment(R"DOC(
...@@ -91,8 +90,8 @@ Scatter Operator. ...@@ -91,8 +90,8 @@ Scatter Operator.
This operator obtains output by updating the input on selected indices on the first axis: This operator obtains output by updating the input on selected indices on the first axis:
$$ $$
Out = Ref \\ Out = X \\
Out[Index] = Ref[Index] + Updates Out[Ids] = X[Ids] + Updates
$$ $$
)DOC"); )DOC");
......
...@@ -25,14 +25,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -25,14 +25,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto *Ref = ctx.Input<Tensor>("Ref"); auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates"); auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out"); auto *Out = ctx.Output<Tensor>("Out");
Out->ShareDataWith(*Ref); Out->ShareDataWith(*X);
GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out); GPUScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
} }
}; };
...@@ -42,16 +42,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -42,16 +42,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref")); auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates")); auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO // In place gradient: dX = dO
dRef->ShareDataWith(*dOut); dX->ShareDataWith(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Index] // Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates); GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
} }
}; };
......
...@@ -29,15 +29,15 @@ class ScatterOpKernel : public framework::OpKernel<T> { ...@@ -29,15 +29,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU."); "This kernel only runs on CPU.");
auto *Ref = ctx.Input<Tensor>("Ref"); auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates"); auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out"); auto *Out = ctx.Output<Tensor>("Out");
// In place output: Out = Ref, Out[Index] += Updates // In place output: Out = X, Out[Ids] += Updates
Out->ShareDataWith(*Ref); Out->ShareDataWith(*X);
// Apply ScatterUpdate: Out[index] += Updates[:] // Apply ScatterUpdate: Out[index] += Updates[:]
ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out); ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
} }
}; };
...@@ -47,16 +47,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> { ...@@ -47,16 +47,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU."); "This kernel only runs on CPU.");
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref")); auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates")); auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO // In place gradient: dX = dO
dRef->ShareDataWith(*dOut); dX->ShareDataWith(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index] // Gradient by Gather: dUpdates += dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates); CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
} }
}; };
......
...@@ -45,31 +45,13 @@ __activations__ = [ ...@@ -45,31 +45,13 @@ __activations__ = [
] ]
__all__ = [ __all__ = [
'mean', 'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
'mul', 'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
'reshape', 'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
'scale', 'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
'sigmoid_cross_entropy_with_logits', 'logical_xor', 'logical_not', 'uniform_random',
'elementwise_add', 'uniform_random_batch_size_like', 'gaussian_random',
'elementwise_div', 'gaussian_random_batch_size_like', 'cumsum', 'scatter'
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'clip',
'clip_by_norm',
'softmax',
'sequence_softmax',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'cumsum',
] + __activations__ ] + __activations__
for _OP in set(__all__): for _OP in set(__all__):
......
...@@ -25,7 +25,7 @@ class TestScatterOp(OpTest): ...@@ -25,7 +25,7 @@ class TestScatterOp(OpTest):
updates_np = np.random.random((2, 3)).astype("float32") updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np) output_np = np.copy(ref_np)
output_np[index_np] = updates_np output_np[index_np] = updates_np
self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np} self.outputs = {'Out': output_np}
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册