提交 5a4367bb 编写于 作者: Y Yang Yu

Update

上级 63e31507
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
......@@ -26,12 +27,16 @@ class ActivationKernel
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
"Cannot get input tensor X, variable name = %s",
context.op().Input("X"));
auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
"Cannot get output tensor Out, variable name = %s",
context.op().Output("Out"));
Out.mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(X);
auto out = framework::EigenVector<T>::Flatten(Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册