tensorflow 2.0 学习 Himmelblua函数求极值

Posted heze

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow 2.0 学习 Himmelblua函数求极值相关的知识,希望对你有一定的参考价值。

Himmelblua函数在(-6,6),(-6,6)的二维平面上求极值

函数的数学表达式:f(x, y) = (x**2 + y -11)**2 + (x + y**2 -7)**2; 如下图所示

技术图片

等高线如下图所示:

技术图片

代码如下:

# encoding: utf-8

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import ImageGrid


# Himmelblau function
def himmelblua(x):
    return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2


# 产生三维数据
x = np.arange(-6, 6, 0.1)           # 创建等差数组,步长为0.1
y = np.arange(-6, 6, 0.1)
print(x, y range:, x.shape, y.shape)
X, Y = np.meshgrid(x, y)
print(X, Y maps:, X.shape, Y.shape)
Z = himmelblua([X, Y])
max = np.max(Z)
min = np.min(Z)

# 画三维图
fig = plt.figure(himmelblua)
ax = fig.gca(projection=3d)       # 设置3D坐标轴
ax.plot_surface(X, Y, Z)            # 3D曲面图
ax.view_init(60, -30)
ax.set_xlabel(x)
ax.set_ylabel(y)
plt.show()

# 画等高线图
N = np.arange(min, max, (max-min)/200)
fig = plt.figure(contour)
ct = plt.contour(Z, N, linewidth=2, cmap=mpl.cm.jet)        # 计算等高差
plt.clabel(ct, inline=True, fmt=%1.1f, fontsize=10)
plt.colorbar(ct)
plt.xlabel(x)
plt.ylabel(y)
plt.savefig(contour-himmelblua.png)
plt.show()

# 初始化参数
x = tf.constant([4., 0.])

# 寻找极小值数值解
for step in range(200):
    with tf.GradientTape() as tape:
        tape.watch([x])
        y = himmelblua(x)
    grads = tape.gradient(y, [x])[0]
    x -= 0.01*grads
    if step % 20 ==19:
        print(step {}: x={}, f(x)={}.format(step, x.numpy(), y.numpy()))

经过迭代后的值越来越精确,这里就不表了!

以上是关于tensorflow 2.0 学习 Himmelblua函数求极值的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow 2 / 2.0 入门教程实战案例

Tensorflow 2.0的自定义训练循环的学习率

Tensorflow 2.0

tensorflow 2.0 学习 tensorboard可视化功能认识

tensorflow 2.0 学习 反向传播代码逐步实现

tensorflow 2.0 学习 Himmelblua函数求极值