mars.learn.contrib.tensorflow.gen_tensorflow_dataset#

mars.learn.contrib.tensorflow.gen_tensorflow_dataset(tensors, output_shapes: 可选[元组[整型, ...]] = None, output_types: 可选[元组[数据类型, ...]] = None, fetch_kwargs=None)[来源]#

将 Mars 数据类型转换为 tf.data.Dataset。请注意,这是基于 tensorflow 2.0 例如 ———– >>> # 将一个张量转换为 tf.data.Dataset。 >>> data = mt.tensor([[1, 2], [3, 4]]) >>> dataset = gen_tensorflow_dataset(data) >>> list(dataset.as_numpy_iterator()) [array([1, 2]), array([3, 4])] >>> dataset.element_spec TensorSpec(shape=(2,), dtype=tf.int64, name=None)

>>> # convert a tuple of tensors to tf.data.Dataset.
>>> data1 = mt.tensor([1, 2]); data2 = mt.tensor([3, 4]); data3 = mt.tensor([5, 6])
>>> dataset = gen_tensorflow_dataset((data1, data2, data3))
>>> list(dataset.as_numpy_iterator())
[(1, 3, 5), (2, 4, 6)]
Parameters
  • 张量 (Mars 数据类型Mars 数据类型 组成的元组) – 转换为 tf.data.dataset 的数据

  • output_shapes – 一个(嵌套的)tf.TensorShape 对象结构,对应于从 mars 对象中生成的每个元素的组成部分。

  • output_types – 一个对应于从mars对象生成的每个元素的组件的 tf.DType 对象的(嵌套)结构。

  • fetch_kwargs – mars对象的参数,用于执行fetch()操作。

Return type

tf.data.Dataset