一个Java老鸟的TensorFlow入门——从计算图到GradientTape
写了20年Java,突然要学TensorFlow,第一反应是:这东西怎么这么绕?TF 1.x的计算图、Session、placeholder,跟Java的思维方式完全不一样。后来TF 2.x出了GradientTape,终于顺畅了。这篇记录我从零开始学TensorFlow的过程,不是教程,是一个老程序员的踩坑笔记。
一、TF 1.x:先建图,再跑图
第一个程序:常量加法
importtensorflow.compat.v1astf tf.disable_eager_execution()a=tf.constant(3.0,name="node1")b=tf.constant(4.0,name="node2")c=tf.add(a,b)withtf.Session()assess:print(sess.run(c))# 7.0Java程序员的困惑:为什么不直接3.0 + 4.0?
因为TF 1.x的设计思路是"先画蓝图,再施工"。tf.constant(3.0)不是在算3.0,是在图里画了一个节点。tf.add(a, b)也不是在算加法,是在图里画了一条从a、b到c的边。直到sess.run(c),施工队才开始干活。
这种"声明式"编程在Java里也有类似的东西——SQL。你写SELECT * FROM t WHERE id = 1,也不是在执行,是在描述你想要什么,数据库引擎去执行。TF 1.x的计算图也是这个意思。
第二个程序:变量与累加
value1=tf.Variable(0.0)const1=tf.constant(1.0)sum1=tf.Variable(0.0)new_value1=tf.add(value1,const1)value1=value1.assign(new_value1)sum1=tf.assign_add(sum1,value1)sess=tf.Session()init=tf.global_variables_initializer()sess.run(init)foriinrange(10):result=sess.run([value1,sum1])print("第%d次, 累加:%d, 和:%d"%(i+1,result[0],result[1]))这里有两点跟Java不一样:
- 变量要初始化——
tf.global_variables_initializer(),不调这个,变量是空的。Java里int i = 0就完了,TF里要显式告诉Session"请初始化所有变量"。 - 赋值是操作不是语句——
value1.assign(new_value1)返回的是一个操作节点,不是立刻赋值。得sess.run()才生效。
第三个程序:占位符(placeholder)
a=tf.placeholder(tf.float32,name="a")b=tf.placeholder(tf.float32,name="b")c=tf.add(a,b)d=tf.multiply(a,b)withtf.Session()assess:result=sess.run([c,d],feed_dict={a:[1.0,2.0,3.0],b:[4.0,5.0,6.0]})print(result[0])# [5.0, 7.0, 9.0]print(result[1])# [4.0, 10.0, 18.0]placeholder就是方法的参数。先在图里留个坑,运行的时候用feed_dict填数据。Java程序员可以理解为接口定义——你声明了参数类型,调用时传具体值。
还顺手把计算图写到了TensorBoard日志:
writer=tf.summary.FileWriter("e:\\log",tf.get_default_graph())打开TensorBoard可以看到可视化计算图——节点和边的拓扑结构。调试时很有用。
二、TF 1.x的痛点
学了三个例子之后,我感觉到几个不舒服的地方:
- 所有东西都得在图里——想打个中间变量的值?
sess.run()。想看类型?图里没有运行时类型。 - 调试困难——图建好了,跑不了断点。出错了报错信息跟图节点名相关,不是Python代码行号。
- 代码啰嗦——建图、初始化、Session、feed_dict,干个加法要写一堆。
这不是TF的问题,是"声明式"编程的代价。SQL也有类似问题——复杂SQL调试起来也很难。
三、TF 2.x:终于像正常代码了
GradientTape做多项式回归
importtensorflowastfimportnumpyasnpimportmatplotlib.pyplotasplt np.random.seed(0)X=np.linspace(-1,1,100)Y=0.5*X**2+0.5*X+2+np.random.normal(0,0.05,(100,))X_train,Y_train=X[:70],Y[:70]X_test,Y_test=X[70:],Y[70:]W1=tf.Variable(np.random.randn())W2=tf.Variable(np.random.randn())b=tf.Variable(np.random.randn())deflinear_regression(x):returnW1*x**2+W2*x+b optimizer=tf.optimizers.SGD(learning_rate=0.01)forstepinrange(100):withtf.GradientTape()astape:pred=linear_regression(X_train)loss=tf.reduce_mean(tf.square(pred-Y_train))gradients=tape.gradient(loss,[W1,W2,b])optimizer.apply_gradients(zip(gradients,[W1,W2,b]))if(step+1)%20==0:print("Step: %i, loss: %f, W1: %f, W2: %f, b: %f"%(step+1,loss,W1.numpy(),W2.numpy(),b.numpy()))对比TF 1.x,变化巨大:
- 不需要Session了——直接执行,像正常Python代码
- 不需要建图了——
GradientTape自动记录前向计算过程 - 调试方便——
W1.numpy()随时可以看值,不需要sess.run() - 代码量少了一半
GradientTape的核心思想:用with tf.GradientTape() as tape包住前向计算,TF自动记录所有操作。然后tape.gradient(loss, [参数])自动求导。不需要手写反向传播,不需要理解链式法则的推导过程。
Java程序员可以类比:TF 1.x像JDBC(手动管理连接、Statement、ResultSet),TF 2.x像MyBatis(框架帮你搞定底层,你只写业务逻辑)。
四、Keras:加载现成数据集
fromkeras.api.datasetsimportmnist,imdb(train_images,train_labels),(test_images,test_labels)=mnist.load_data()print(train_images.shape)# (60000, 28, 28)(train_datas,train_labels),(_,_)=imdb.load_data()word_index=imdb.get_word_index()reverse_word_index=dict([(value,key)for(key,value)inword_index.items()])decode_view=''.join(reverse_word_index.get(i-3,'?')foriintrain_datas[3])print(decode_view)Keras内置了常用数据集,mnist.load_data()直接下载手写数字,imdb.load_data()直接下载电影评论。IMDB的数据已经转成了词索引,通过word_index反查可以还原原始文本。
这一步没什么技术含量,但省了很多数据准备的时间。学习阶段用现成数据集,项目阶段用自己的数据——这个节奏是对的。
五、总结:一个Java老兵的TF学习路径
| 阶段 | 我做了什么 | 关键收获 |
|---|---|---|
| TF 1.x常量 | 建图、Session、run() | 理解"声明式"编程 |
| TF 1.x变量 | Variable、assign、初始化 | 变量是图的一部分 |
| TF 1.x占位符 | placeholder、feed_dict | 参数化计算图 |
| TF 2.x GradientTape | 自动求导、多项式回归 | 终于像正常代码了 |
| Keras数据集 | MNIST、IMDB加载 | 数据准备的起点 |
最大的体会:如果你现在开始学TensorFlow,直接学TF 2.x。TF 1.x的计算图概念了解一下就行(很多老教程和老项目还在用),但写代码用2.x。GradientTape + Eager Execution,学习曲线平很多。
环境搭建我踩的坑:
- Python版本:用3.9-3.11,太新可能TF不支持
- TensorFlow安装:
pip install tensorflow,GPU版装tensorflow-gpu(需要CUDA和cuDNN,很折腾,学习阶段CPU够用) - 如果只是学基础,CPU版就行,MNIST和线性回归秒跑完
相关阅读:
- *《一个46岁架构师的AI实战经验总结》*
- *《老鸟的JVM理解——不是背出来的,是搬对象搬出来的》*