mars.learn.contrib.tensorflow.gen_tensorflow_dataset#
- mars.learn.contrib.tensorflow.gen_tensorflow_dataset(tensors, output_shapes: 可选[元组[整型, ...]] = None, output_types: 可选[元组[数据类型, ...]] = None, fetch_kwargs=None)[来源]#
将 Mars 数据类型转换为 tf.data.Dataset。请注意,这是基于 tensorflow 2.0 例如 ———– >>> # 将一个张量转换为 tf.data.Dataset。 >>> data = mt.tensor([[1, 2], [3, 4]]) >>> dataset = gen_tensorflow_dataset(data) >>> list(dataset.as_numpy_iterator()) [array([1, 2]), array([3, 4])] >>> dataset.element_spec TensorSpec(shape=(2,), dtype=tf.int64, name=None)
>>> # convert a tuple of tensors to tf.data.Dataset. >>> data1 = mt.tensor([1, 2]); data2 = mt.tensor([3, 4]); data3 = mt.tensor([5, 6]) >>> dataset = gen_tensorflow_dataset((data1, data2, data3)) >>> list(dataset.as_numpy_iterator()) [(1, 3, 5), (2, 4, 6)]
- Parameters
张量 (Mars 数据类型 或 由 Mars 数据类型 组成的元组) – 转换为 tf.data.dataset 的数据
output_shapes – 一个(嵌套的)tf.TensorShape 对象结构,对应于从 mars 对象中生成的每个元素的组成部分。
output_types – 一个对应于从mars对象生成的每个元素的组件的 tf.DType 对象的(嵌套)结构。
fetch_kwargs – mars对象的参数,用于执行fetch()操作。
- Return type
tf.data.Dataset