在进行keras 网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy 里的A.shape()。这样的形式来获取。这里需要调用一下keras 作为后端的方式来获取。当我们想要操作时第一时间就想到直接用 shape ()函数。其实keras 中真的有shape()这个函数。
shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,
示例:
> from keras import backend as K > tf_session = K.get_session() > val = np.array([[1, 2], [3, 4]]) > kvar = K.variable(value=val) > input = keras.backend.placeholder(shape=(2, 4, 5)) > K.shape(kvar) <tf.Tensor 'Shape_8:0' shape=(2,) dtype=int32> > K.shape(input) <tf.Tensor 'Shape_9:0' shape=(3,) dtype=int32> __To get integer shape (Instead, you can use K.int_shape(x))__ > K.shape(kvar).eval(session=tf_session) array([2, 2], dtype=int32) > K.shape(input).eval(session=tf_session) array([2, 4, 5], dtype=int32)
如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用 int_shape(x) 函数。这个函数才是我们想要的。
> from keras import backend as K > input = K.placeholder(shape=(2, 4, 5)) > K.int_shape(input) (2, 4, 5) > val = np.array([[1, 2], [3, 4]]) > kvar = K.variable(value=val) > K.int_shape(kvar) (2, 2)
最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras 层了。
补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)
tf.shape(a)和a.get_shape()比较
相同点:都可以得到tensor a的尺寸
不同点:tf.shape()中a 数据的类型可以是tensor, list, array
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)
import tensorflow as tf import numpy as np x=tf.constant([[1,2,3],[4,5,6]]) y=[[1,2,3],[4,5,6]] z=np.arange(24).reshape([2,3,4]) sess=tf.Session() # tf.shape() x_shape=tf.shape(x) # x_shape 是一个tensor y_shape=tf.shape(y) # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32> z_shape=tf.shape(z) # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32> print(sess.run(x_shape)) # 结果:[2 3] print(sess.run(y_shape)) # 结果:[2 3] print(sess.run(z_shape) ) # 结果:[2 3 4] x_shape=x.get_shape() print(x_shape) # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组 (2, 3) x_shape=x.get_shape().as_list() print(x_shape) # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3] 这是重点 返回列表方便参加其他代码的运算 # y_shape=y.get_shape() print(x_shape)# AttributeError: 'list' object has no attribute 'get_shape' # z_shape=z.get_shape() print(x_shape)# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape' 或者a.shape.as_list()
以上这篇在keras 中获取张量 tensor 的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
RTX 5090要首发 性能要翻倍!三星展示GDDR7显存
三星在GTC上展示了专为下一代游戏GPU设计的GDDR7内存。
首次推出的GDDR7内存模块密度为16GB,每个模块容量为2GB。其速度预设为32 Gbps(PAM3),但也可以降至28 Gbps,以提高产量和初始阶段的整体性能和成本效益。
据三星表示,GDDR7内存的能效将提高20%,同时工作电压仅为1.1V,低于标准的1.2V。通过采用更新的封装材料和优化的电路设计,使得在高速运行时的发热量降低,GDDR7的热阻比GDDR6降低了70%。
更新日志
- 魔兽世界奥卡兹岛地牢入口在哪里 奥卡兹岛地牢入口位置一览
- 和文军-丽江礼物[2007]FLAC
- 陈随意2012-今生的伴[豪记][WAV+CUE]
- 罗百吉.2018-我们都一样【乾坤唱片】【WAV+CUE】
- 《怪物猎人:荒野》不加中配请愿书引热议:跪久站不起来了?
- 《龙腾世纪4》IGN 9分!殿堂级RPG作品
- Twitch新规禁止皮套外露敏感部位 主播直接“真身”出镜
- 木吉他.1994-木吉他作品全集【滚石】【WAV+CUE】
- 莫华伦.2022-一起走过的日子【京文】【WAV+CUE】
- 曾淑勤.1989-装在袋子里的回忆【点将】【WAV+CUE】
- 滚石香港黄金十年系列《赵传精选》首版[WAV+CUE][1.1G]
- 雷婷《乡村情歌·清新民谣》1:1母盘直刻[低速原抓WAV+CUE][1.1G]
- 群星 《DJ夜色魅影HQⅡ》天艺唱片[WAV+CUE][1.1G]
- 群星《烧透你的耳朵2》DXD金佰利 [低速原抓WAV+CUE][1.3G]
- 群星《难忘的回忆精选4》宝丽金2CD[WAV+CUE][1.4G]