引言

近来敲代码过程中遇到了一种情况,就是我希望直接去修改一个张量的数据指针,从而达到数据复用的目的,因为tensorflow是不会复用numpy数据的(torch的from_ numpy方法会)。虽然最后找到了更好的方法,不需要做这种极度危险的操作,但这个探索的过程颇为有意思,加之网上对Tensorflow的张量内存布局基本没有什么介绍,所以在此记录,以下所有内容是针对tf2.11的源码所讲。

对了,这个操作是Tensorflow.NET相关生态里面用到的,可以关注下项目谢谢喵。

张量内存布局

首先,C API暴露给上层的张量结构体和tensorflow库内部使用的张量结构体是有一点不同的,但也只是一层比较薄的封装,如下所示:

typedef struct TF_Tensor {
  tensorflow::AbstractTensorInterface* tensor;
} TF_Tensor;

对于我们的需求,我们的目的实际上是要通过一个TF_Tensor对象,我们根据内存布局,计算出数据指针所在的地址并进行修改,所以必须继续探索内部的内存布局。

一个TF_Tensor对象内部仅包含了一个AbstractTensorInterface类型的指针,而这个类的定义如下:

class AbstractTensorInterface {
 public:
  // Release any underlying resources, including the interface object.
  virtual void Release() = 0;

  // Returns tensor dtype.
  virtual DataType Type() const = 0;
  // Returns number of dimensions.
  virtual int NumDims() const = 0;
  // Returns size of specified dimension
  virtual int64_t Dim(int dim_index) const = 0;
  // Returns number of elements across all dimensions.
  virtual int64_t NumElements() const = 0;
  // Return size in bytes of the Tensor
  virtual size_t ByteSize() const = 0;
  // Returns a pointer to tensor data
  virtual void* Data() const = 0;

  // Returns if the tensor is aligned
  virtual bool IsAligned() const = 0;
  // Returns if their is sole ownership of this Tensor and thus it can be moved.
  virtual bool CanMove() const = 0;

  virtual std::string SummarizeValue() const = 0;

 protected:
  virtual ~AbstractTensorInterface() {}
};

可以看出,这是一个抽象类,内部包含纯虚函数,而它的实现之一(也是最常用的实现)如下:

class TensorInterface : public AbstractTensorInterface {
 public:
  TensorInterface() {}
  explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
  ~TensorInterface() override {}

  void Release() override;

  DataType Type() const override;
  int NumDims() const override;
  int64_t Dim(int dim_index) const override;
  int64_t NumElements() const override;
  size_t ByteSize() const override;
  void* Data() const override;
  bool IsAligned() const override;
  bool CanMove() const override;
  std::string SummarizeValue() const override;

  void SetShape(const int64_t* dims, int num_dims);
  Status ToTensor(tensorflow::Tensor* dst) const;
  Status BitcastFrom(const TensorInterface& from, DataType type,
                     const int64_t* new_dims, int num_new_dims);
  Status FromProto(const tensorflow::TensorProto& from);

  tensorflow::Tensor& Tensor() { return tensor_; }

 private:
  tensorflow::Tensor tensor_;
};

这时候就已经看到真正想要的东西了,也就是私有成员tensor_,这就是tensorflow库的内部进行计算等操作时实际使用的张量,它的定义比较庞大,但是只包含两个成员,其简化后的定义如下:

struct Tensor{
  TensorShape shape_;
  TensorBuffer* buf_;
}

其实看到这里,我是有一点奇怪的,因为数据类型哪去了?先按下不表,继续往后看TensorShape的定义:

class TensorShape : public TensorShapeBase<TensorShape>{
    ...
}

这并非一个空类,但是内部没有独有的成员,所以此处将方法的定义略去。它的基类是一个模板类,继承自TensorShapeRep,但是同样没有任何独有的成员,所以我们直接看最终的TensorShapeRep的定义:

class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  // can be reassigned into, but its dimensions and number of elements can be
  // nonsensical (e.g., negative dimension sizes, or number of elements not
  // properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64_t num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64_t, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64_t num_elements_;
};

首先,在最末尾处,我们已经可以判断出TensorShapeRep的尺寸,那就是24字节,但是union到底是做什么用的呢?这里就涉及到了标志位的问题,前面的数据类型为什么没有这一问题也可以得到解答。u_这个Union规定了16字节的大小,其中,第14个字节用来表示数据类型,第15个字节用来表示维度,第16个字节用来表示tag,而tag是一个规定如何表示形状的标志位,如果为0,那么就是Rep16的表示,每个维度上的形状使用两个字节表示,这样_u的前12个字节就可以表示6个维度的数据;如果是1,那么就是Rep32的标志,每个维度上的形状用四个字节表示;如果是2,那么就是Rep64的表示,每个维度要用8个字节来表示,这时候就已经不能存放在TensorShapeRep本身的内存布局中了,而是TensorShapeRep持有一个指向InlinedVector的指针,形状信息放在这个对象里面。

Anyway,我们这里已经得到了最宝贵的信息,那就是TensorShape无论如何,其本身内存布局只有24字节。接下来看TensorBuffer,这个类实际上是一个基类,在实际运行的时候实际上会是某个派生类对象,但由于c++的内存布局是基类靠前,并且TensorBuffer中已经可以找到数据指针,所以这里就不对派生类展开介绍。TensorBuffer的定义如下:

class TensorBuffer : public core::RefCounted {
 public:
  explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
  ~TensorBuffer() override {}

  /// \brief data() points to a memory region of size() bytes.
  ///
  /// NOTE(mrry): The `data()` method is not virtual for performance reasons.
  /// It can be called multiple times when the contents of a `Tensor` are
  /// accessed, and so making it non-virtual allows the body to be inlined.
  void* data() const { return data_; }

  /// \brief Size (in bytes) of the buffer.
  virtual size_t size() const = 0;

  /// \brief If this TensorBuffer is sub-buffer of another TensorBuffer,
  /// returns that TensorBuffer. Otherwise, returns this.
  virtual TensorBuffer* root_buffer() = 0;

  /// \brief Fills metadata about the allocation into the proto.
  virtual void FillAllocationDescription(
      AllocationDescription* proto) const = 0;

  virtual bool GetAllocatedBytes(size_t* out_bytes) const;

  /// \brief Helper method to reinterpret the buffer as an array of `T`.
  template <typename T>
  T* base() const {
    return reinterpret_cast<T*>(data());
  }

  /// \brief Whether this TensorBuffer owns the underlying memory.
  virtual bool OwnsMemory() const { return true; }

  /// \brief The type of the underlying memory.
  virtual AllocatorMemoryType GetMemoryType() const {
    return AllocatorMemoryType::kUnknown;
  }

 private:
  void* const data_;
};

这个类实际上是有引用计数的,方式就是继承自RefCounted这个类,这个类这里不展开详细介绍,只说结果:这个类在64位系统下的大小是16字节。这是因为类本身包含了一个int32类型的计数器,占4字节,然后类本身是有虚函数的,所以有一个大小为8字节的指向虚表的指针,然后根据64位系统下的对齐规则,会对齐为16字节。至此,张量的内存布局已经基本明朗。

通过偏移设置数据指针

假设我们现在(在其它语言中)持有了一个C API返回的TF_Tensor*指针tt_ptr,那么我们可以根据以下步骤来设置张量的数据指针:

  1. 对tt_ptr进行取值,获取TF_Tensor对象本身的起始地址tt_ad。
  2. 直接将tt_ad的前8个字节强制转换成一个指针,并且取出指针本身的值,这时候就是AbstractInterface这个对象的起始地址ai_ad。
  3. 对ai_ad施加大小为8字节的偏移,这时候就是Tensor对象的起始地址t_ad,这里施加一个偏移是因为AbstractInterface对象有虚函数,可以参见上文。
  4. 再次施加8*3字节的偏移,然后强制转换成指针并取出指针的值,就可以得到TensorBuffer的起始地址tb_ad。这里的偏移是因为跳过了TensorShape部分。
  5. 对tb_ad施加8*2字节的偏移,再取出接下来8字节所表示的值,这里是因为前面所讲的数据指针前有一个虚表指针和一个计数器。

经过上面这五步,我们就可以拿到数据指针并且修改它。

C API中提供了一个TF_TensorData函数获取数据指针,但需要注意,这个API只能获取到数据指针的值,并不能修改数据指针的指向。

结语

上面的整个过程都是自己对着源码推理出来的,自己并没有去修改源码来输出这些对象的内存布局,感觉还是挺有意思的。Tensorflow内部张量的内存布局和大多数AI框架的差别不大,比较巧妙的设计在于TensorShapeRep这里,但我觉得将数据类型放在TensorShape里面还是会有点别扭吧。

最后,本文所讲的直接修改数据指针的方法是一种非常危险的操作,如果不是万不得已,还是不要使用。