提交 2ccaec4f 编写于 作者: Z zchen0211

gather scatter cond

上级 3d09a654
...@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet( ...@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet(
dim[0] = index_tensors[i].dims()[0]; dim[0] = index_tensors[i].dims()[0];
tensor_child->mutable_data<float>(dim, platform::CPUPlace()); tensor_child->mutable_data<float>(dim, platform::CPUPlace());
CPUGather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
tensor_child);
} }
} }
...@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, ...@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
Variable* var_child = sub_scopes[i]->FindVar(output); Variable* var_child = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_child); PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = &var_child->Get<LoDTensor>(); auto* tensor_child = &var_child->Get<LoDTensor>();
ScatterAssign<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
tensor_parent); tensor_parent);
} }
} }
......
...@@ -32,11 +32,11 @@ namespace operators { ...@@ -32,11 +32,11 @@ namespace operators {
* return: output tensor * return: output tensor
*/ */
template <typename T> template <typename T>
void CPUGather(const platform::Place& place, void CPUGather(const platform::DeviceContext& ctx,
const paddle::framework::Tensor* src, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index, const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) { paddle::framework::Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(place)); PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D // check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1); PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0]; int index_size = index->dims()[0];
......
...@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> { ...@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
CPUGather<T>(ctx.GetPlace(), x, index, output); CPUGather<T>(ctx.device_context(), x, index, output);
} }
}; };
...@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> { ...@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
auto place = ctx.GetEigenDevice<platform::CPUPlace>(); auto place = ctx.GetEigenDevice<platform::CPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
ScatterAssign<T>(ctx.GetPlace(), dO, Index, dX); ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
} }
}; };
......
...@@ -41,7 +41,9 @@ TEST(Gather, GatherData) { ...@@ -41,7 +41,9 @@ TEST(Gather, GatherData) {
int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace()); int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
CPUGather<int>(CPUPlace(), src, index, output); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
CPUGather<int>(ctx, src, index, output);
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
......
...@@ -33,11 +33,11 @@ using Tensor = framework::Tensor; ...@@ -33,11 +33,11 @@ using Tensor = framework::Tensor;
* return: output tensor * return: output tensor
*/ */
template <typename T> template <typename T>
void ScatterAssign(const platform::Place& place, void ScatterAssign(const platform::DeviceContext& ctx,
const paddle::framework::Tensor* src, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index, const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) { paddle::framework::Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(place)); PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D // check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1); PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0]; int index_size = index->dims()[0];
......
...@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> { ...@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
// In place output: Out = Ref, Out[Index] += Updates // In place output: Out = Ref, Out[Index] += Updates
Out->ShareDataWith<T>(*Ref); Out->ShareDataWith<T>(*Ref);
// Apply ScatterUpdate: Out[index] += Updates[:] // Apply ScatterUpdate: Out[index] += Updates[:]
ScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out); ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
} }
}; };
...@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> { ...@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
dRef->ShareDataWith<T>(*dOut); dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index] // Gradient by Gather: dUpdates += dO[Index]
CPUGather<T>(ctx.GetPlace(), dOut, Index, dUpdates); CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
} }
}; };
......
...@@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) { ...@@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) {
float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace()); float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
ScatterAssign<float>(CPUPlace(), src, index, output); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
ScatterAssign<float>(ctx, src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0)); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册