news 2026/4/25 13:04:59

# 一个Java老鸟的TensorFlow入门——从计算图到GradientTape

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
# 一个Java老鸟的TensorFlow入门——从计算图到GradientTape

一个Java老鸟的TensorFlow入门——从计算图到GradientTape

写了20年Java,突然要学TensorFlow,第一反应是:这东西怎么这么绕?TF 1.x的计算图、Session、placeholder,跟Java的思维方式完全不一样。后来TF 2.x出了GradientTape,终于顺畅了。这篇记录我从零开始学TensorFlow的过程,不是教程,是一个老程序员的踩坑笔记。


一、TF 1.x:先建图,再跑图

第一个程序:常量加法

importtensorflow.compat.v1astf tf.disable_eager_execution()a=tf.constant(3.0,name="node1")b=tf.constant(4.0,name="node2")c=tf.add(a,b)withtf.Session()assess:print(sess.run(c))# 7.0

Java程序员的困惑:为什么不直接3.0 + 4.0

因为TF 1.x的设计思路是"先画蓝图,再施工"。tf.constant(3.0)不是在算3.0,是在图里画了一个节点。tf.add(a, b)也不是在算加法,是在图里画了一条从a、b到c的边。直到sess.run(c),施工队才开始干活。

这种"声明式"编程在Java里也有类似的东西——SQL。你写SELECT * FROM t WHERE id = 1,也不是在执行,是在描述你想要什么,数据库引擎去执行。TF 1.x的计算图也是这个意思。

第二个程序:变量与累加

value1=tf.Variable(0.0)const1=tf.constant(1.0)sum1=tf.Variable(0.0)new_value1=tf.add(value1,const1)value1=value1.assign(new_value1)sum1=tf.assign_add(sum1,value1)sess=tf.Session()init=tf.global_variables_initializer()sess.run(init)foriinrange(10):result=sess.run([value1,sum1])print("第%d次, 累加:%d, 和:%d"%(i+1,result[0],result[1]))

这里有两点跟Java不一样:

  1. 变量要初始化——tf.global_variables_initializer(),不调这个,变量是空的。Java里int i = 0就完了,TF里要显式告诉Session"请初始化所有变量"。
  2. 赋值是操作不是语句——value1.assign(new_value1)返回的是一个操作节点,不是立刻赋值。得sess.run()才生效。

第三个程序:占位符(placeholder)

a=tf.placeholder(tf.float32,name="a")b=tf.placeholder(tf.float32,name="b")c=tf.add(a,b)d=tf.multiply(a,b)withtf.Session()assess:result=sess.run([c,d],feed_dict={a:[1.0,2.0,3.0],b:[4.0,5.0,6.0]})print(result[0])# [5.0, 7.0, 9.0]print(result[1])# [4.0, 10.0, 18.0]

placeholder就是方法的参数。先在图里留个坑,运行的时候用feed_dict填数据。Java程序员可以理解为接口定义——你声明了参数类型,调用时传具体值。

还顺手把计算图写到了TensorBoard日志:

writer=tf.summary.FileWriter("e:\\log",tf.get_default_graph())

打开TensorBoard可以看到可视化计算图——节点和边的拓扑结构。调试时很有用。


二、TF 1.x的痛点

学了三个例子之后,我感觉到几个不舒服的地方:

  1. 所有东西都得在图里——想打个中间变量的值?sess.run()。想看类型?图里没有运行时类型。
  2. 调试困难——图建好了,跑不了断点。出错了报错信息跟图节点名相关,不是Python代码行号。
  3. 代码啰嗦——建图、初始化、Session、feed_dict,干个加法要写一堆。

这不是TF的问题,是"声明式"编程的代价。SQL也有类似问题——复杂SQL调试起来也很难。


三、TF 2.x:终于像正常代码了

GradientTape做多项式回归

importtensorflowastfimportnumpyasnpimportmatplotlib.pyplotasplt np.random.seed(0)X=np.linspace(-1,1,100)Y=0.5*X**2+0.5*X+2+np.random.normal(0,0.05,(100,))X_train,Y_train=X[:70],Y[:70]X_test,Y_test=X[70:],Y[70:]W1=tf.Variable(np.random.randn())W2=tf.Variable(np.random.randn())b=tf.Variable(np.random.randn())deflinear_regression(x):returnW1*x**2+W2*x+b optimizer=tf.optimizers.SGD(learning_rate=0.01)forstepinrange(100):withtf.GradientTape()astape:pred=linear_regression(X_train)loss=tf.reduce_mean(tf.square(pred-Y_train))gradients=tape.gradient(loss,[W1,W2,b])optimizer.apply_gradients(zip(gradients,[W1,W2,b]))if(step+1)%20==0:print("Step: %i, loss: %f, W1: %f, W2: %f, b: %f"%(step+1,loss,W1.numpy(),W2.numpy(),b.numpy()))

对比TF 1.x,变化巨大:

  1. 不需要Session了——直接执行,像正常Python代码
  2. 不需要建图了——GradientTape自动记录前向计算过程
  3. 调试方便——W1.numpy()随时可以看值,不需要sess.run()
  4. 代码量少了一半

GradientTape的核心思想:用with tf.GradientTape() as tape包住前向计算,TF自动记录所有操作。然后tape.gradient(loss, [参数])自动求导。不需要手写反向传播,不需要理解链式法则的推导过程。

Java程序员可以类比:TF 1.x像JDBC(手动管理连接、Statement、ResultSet),TF 2.x像MyBatis(框架帮你搞定底层,你只写业务逻辑)。


四、Keras:加载现成数据集

fromkeras.api.datasetsimportmnist,imdb(train_images,train_labels),(test_images,test_labels)=mnist.load_data()print(train_images.shape)# (60000, 28, 28)(train_datas,train_labels),(_,_)=imdb.load_data()word_index=imdb.get_word_index()reverse_word_index=dict([(value,key)for(key,value)inword_index.items()])decode_view=''.join(reverse_word_index.get(i-3,'?')foriintrain_datas[3])print(decode_view)

Keras内置了常用数据集,mnist.load_data()直接下载手写数字,imdb.load_data()直接下载电影评论。IMDB的数据已经转成了词索引,通过word_index反查可以还原原始文本。

这一步没什么技术含量,但省了很多数据准备的时间。学习阶段用现成数据集,项目阶段用自己的数据——这个节奏是对的。


五、总结:一个Java老兵的TF学习路径

阶段我做了什么关键收获
TF 1.x常量建图、Session、run()理解"声明式"编程
TF 1.x变量Variable、assign、初始化变量是图的一部分
TF 1.x占位符placeholder、feed_dict参数化计算图
TF 2.x GradientTape自动求导、多项式回归终于像正常代码了
Keras数据集MNIST、IMDB加载数据准备的起点

最大的体会:如果你现在开始学TensorFlow,直接学TF 2.x。TF 1.x的计算图概念了解一下就行(很多老教程和老项目还在用),但写代码用2.x。GradientTape + Eager Execution,学习曲线平很多。

环境搭建我踩的坑

  • Python版本:用3.9-3.11,太新可能TF不支持
  • TensorFlow安装:pip install tensorflow,GPU版装tensorflow-gpu(需要CUDA和cuDNN,很折腾,学习阶段CPU够用)
  • 如果只是学基础,CPU版就行,MNIST和线性回归秒跑完

相关阅读:

  • *《一个46岁架构师的AI实战经验总结》*
  • *《老鸟的JVM理解——不是背出来的,是搬对象搬出来的》*
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/25 13:04:45

Kuberhealthy 多集群监控方案:跨环境统一监控的架构设计

Kuberhealthy 多集群监控方案:跨环境统一监控的架构设计 【免费下载链接】kuberhealthy A Kubernetes operator for running synthetic checks as pods. Works great with Prometheus! 项目地址: https://gitcode.com/gh_mirrors/ku/kuberhealthy Kuberhealt…

作者头像 李华
网站建设 2026/4/25 13:03:27

3分钟学会:用Speechless永久保存微博记忆的完整指南

3分钟学会:用Speechless永久保存微博记忆的完整指南 【免费下载链接】Speechless 把新浪微博的内容,导出成 PDF 文件进行备份的 Chrome Extension。 项目地址: https://gitcode.com/gh_mirrors/sp/Speechless 你是否曾担心那些记录生活点滴的微博…

作者头像 李华
网站建设 2026/4/25 13:03:23

Staytus数据库架构详解:MySQL数据模型与关系设计

Staytus数据库架构详解:MySQL数据模型与关系设计 【免费下载链接】staytus 💡 An open source solution for publishing the status of your services 项目地址: https://gitcode.com/gh_mirrors/st/staytus Staytus作为一款开源的服务状态发布解…

作者头像 李华
网站建设 2026/4/25 12:58:20

二叉树和表达式树的实现

二叉树的介绍二叉树是树这种数据结果的一种特殊情况,其每个节点的子节点树不能超过两个,二叉树差不多就是树中最常用的特殊结构了。二叉树的分类满二叉树国外定义:由度为0和2的结点构成的树,没有度为1的节点。国内定义&#xff1a…

作者头像 李华