11. Dataset Preparation
Now it is time to apply our model to a simple classification problem. To check that the model can be trained successfully, we will generate a set of points belonging to two classes (shown in different colors in the figure below), and then train the model to classify these points (a binary classification problem).
# imports needed by the code in this section
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

# number of samples in the data set
N_SAMPLES = 1000
# ratio between training and test sets
TEST_SIZE = 0.1
# we will use sklearn's make_moons() to generate the dataset:
# - n_samples: number of samples to generate
# - noise: standard deviation of the Gaussian noise added to the data
# - random_state: random seed; fixing it to an int makes the generated data reproducible
X, y = make_moons(n_samples=N_SAMPLES, noise=0.2, random_state=100)

# split the dataset into training set (90%) & test set (10%)
# - test_size: if a float, it should be between 0.0 and 1.0 and gives the fraction of the
#   dataset placed in the test set; if an int, it gives the absolute number of test samples
# - random_state: seed used by the random number generator
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
Let's take a look at the dataset we have just generated.
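The print statements that produced the output below are not included above; a minimal sketch that would reproduce it, mirroring the Y check further down:

# check the information in X (a sketch; the original print statements are not shown)
print("shape of X: ", X.shape)
print("X = ", X)
print("shape of X_train: ", X_train.shape)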
shape of X:  (1000, 2)
X =  [[ 2.24907069e-05  1.07275825e+00]
 [-5.59037377e-02  4.25241282e-01]
 [ 2.40944879e-02  4.08065802e-01]
 ...
 [ 1.75841594e+00 -5.77404262e-01]
 [ 1.26710180e+00 -4.42980152e-01]
 [-1.75927072e-01  5.83509936e-01]]
shape of X_train:  (900, 2)
# check the information in Y
print("shape of Y: ", y.shape)                 # (1000,)
print("number of zeros in Y: ", np.sum(y==0))  # 500
print("number of ones in Y: ", np.sum(y==1))   # 500
print("Y = ", y)
shape of Y:  (1000,)
number of zeros in Y:  500
number of ones in Y:  500
Y =  [0 0 1 0 1 0 0 1 1 1 1 1 0 1 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 ... 1 1 0 1 1 0 1 1 1 1 1]
# the function making up the graph of a dataset
def make_plot(X, y, plot_name, file_name=None, XX=None, YY=None, preds=None, dark=False):
    if dark:
        plt.style.use('dark_background')
    else:
        sns.set_style("whitegrid")
    plt.figure(figsize=(16, 12))
    axes = plt.gca()
    axes.set(xlabel="$X_1$", ylabel="$X_2$")
    plt.title(plot_name, fontsize=30)
    plt.subplots_adjust(left=0.20)
    plt.subplots_adjust(right=0.80)
    if XX is not None and YY is not None and preds is not None:
        # draw the filled decision regions and the 0.5 decision boundary
        plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=1, cmap=cm.Spectral)
        plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6)
    # scatter the data points, colored by class label
    plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')
    if file_name:
        plt.savefig(file_name)
        plt.close()
make_plot(X, y, "Dataset")
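make_plot() can also overlay a model's decision boundary when given a mesh grid and per-point predictions. A hedged sketch of how such a grid could be constructed (the GRID_* names and bounds are illustrative, not from the original):

# build a 100x100 mesh covering the data (illustrative bounds)
GRID_X_START, GRID_X_END = -1.5, 2.5
GRID_Y_START, GRID_Y_END = -1.0, 2.0
grid = np.mgrid[GRID_X_START:GRID_X_END:100j, GRID_Y_START:GRID_Y_END:100j]
XX, YY = grid
grid_2d = grid.reshape(2, -1).T   # shape (10000, 2): one row per grid point

After training, predictions on grid_2d can be passed back to make_plot() as preds to draw the learned boundary (see the sketch at the end of section 12).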
Exercise: set noise in sklearn.make_moons() to 0, 0.2, 0.4, 0.8, and 1.0 in turn; view and save screenshots of the corresponding dataset plots and training results, then analyze the trend you observe and explain why it occurs (a generation loop is sketched below).
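A minimal loop for generating and saving the dataset plots for this exercise; the output file names are illustrative:

# generate and save a dataset plot for each noise level (sketch)
for noise in [0.0, 0.2, 0.4, 0.8, 1.0]:
    X_n, y_n = make_moons(n_samples=N_SAMPLES, noise=noise, random_state=100)
    make_plot(X_n, y_n, "Dataset (noise={})".format(noise),
              file_name="dataset_noise_{}.png".format(noise))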
Next, fix noise in sklearn.make_moons() at 0.8 and change TEST_SIZE passed to sklearn.model_selection.train_test_split() to 0.98. Compared with the original setting (noise = 0.2, TEST_SIZE = 0.1), how do the accuracy on the training set and the accuracy on the test set each change, and what causes these changes? (The modified configuration is sketched below.)
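A sketch of the modified configuration for this comparison; it reuses the code from section 11 with only the two changed arguments:

# heavier noise, and only 2% of the points kept for training (sketch of the modified setup)
X, y = make_moons(n_samples=N_SAMPLES, noise=0.8, random_state=100)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.98, random_state=42)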
Adjusting noise in make_moons() and TEST_SIZE in train_test_split() changes the distribution and the split of the points in the dataset. With noise fixed at 0.8 and TEST_SIZE set to 0.98, compared with the original case of noise = 0.2 and TEST_SIZE = 0.1, the training process is shown in the line-chart figure and the test results in the accompanying figure.
Conclusion: with the test set at 10% of the data and noise = 0.2, training converges more slowly, but the accuracy and cost curves are smooth, the test performance is close to the model's final performance on the training set, and a good model can be learned. With the test set at 98% of the data and noise = 0.8, training runs faster (each iteration processes far fewer samples), but the accuracy curve is jagged and the test performance differs greatly from the final training performance, i.e. the model overfits. The analysis shows that when the training set is too small, the model can learn only very little from it and a good model cannot be obtained.
12. Model Training and Testing
Now let's call the train function to train the model.
# let's train the neural network
params_values = train(
    X=np.transpose(X_train),
    Y=np.transpose(y_train.reshape((y_train.shape[0], 1))),
    nn_architecture=NN_ARCHITECTURE,
    epochs=10000,
    learning_rate=0.01)
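The transposes arrange the samples column-wise, which matches the shapes printed at the start of the training log below; a quick sanity check:

# samples are stored column-wise: (features, n_samples) and (1, n_samples)
print(np.transpose(X_train).shape)                                 # (2, 900)
print(np.transpose(y_train.reshape((y_train.shape[0], 1))).shape)  # (1, 900)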
Iteration: 00000 - cost: 0.69365 - accuracy: 0.50444
X.shape: (2, 900)
Y_hat.shape: (1, 900)
Y.shape: (1, 900)
Iteration: 00050 - cost: 0.69349 - accuracy: 0.50444
Iteration: 00100 - cost: 0.69334 - accuracy: 0.50444
...
Iteration: 00500 - cost: 0.69238 - accuracy: 0.50444
Iteration: 01000 - cost: 0.69109 - accuracy: 0.50444
Iteration: 01500 - cost: 0.68827 - accuracy: 0.76111
Iteration: 02000 - cost: 0.67875 - accuracy: 0.84222
Iteration: 02500 - cost: 0.62966 - accuracy: 0.84444
Iteration: 03000 - cost: 0.42476 - accuracy: 0.85889
Iteration: 03500 - cost: 0.29209 - accuracy: 0.88111
Iteration: 04000 - cost: 0.26780 - accuracy: 0.88000
Iteration: 04500 - cost: 0.25369 - accuracy: 0.88444
Iteration: 05000 - cost: 0.23821 - accuracy: 0.89000
Iteration: 05500 - cost: 0.21418 - accuracy: 0.90222
Iteration: 06000 - cost: 0.17260 - accuracy: 0.93000
Iteration: 06500 - cost: 0.12711 - accuracy: 0.94667
Iteration: 07000 - cost: 0.10188 - accuracy: 0.96111
Iteration: 07500 - cost: 0.09041 - accuracy: 0.96667
Iteration: 08000 - cost: 0.08453 - accuracy: 0.96556
Iteration: 08500 - cost: 0.08084 - accuracy: 0.96889
Iteration: 09000 - cost: 0.07812 - accuracy: 0.96889
Iteration: 09500 - cost: 0.07588 - accuracy: 0.96889
...
Iteration: 09950 - cost: 0.07442 - accuracy: 0.96889
Finally, we call full_forward_propagation() once to evaluate the trained model on the test set.
# prediction
Y_test_hat, _ = full_forward_propagation(np.transpose(X_test), params_values, NN_ARCHITECTURE)
# accuracy achieved on the test set
acc_test = get_accuracy_value(Y_test_hat, np.transpose(y_test.reshape((y_test.shape[0], 1))))
print("Test set accuracy: {:.2f}".format(acc_test))
Test set accuracy: 0.98
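With the trained parameters in hand, the learned decision boundary can also be visualized by running the mesh grid from section 11 through the network and passing the predictions to make_plot(); a sketch under the same assumptions (XX, YY, and grid_2d as constructed earlier):

# predict on every grid point and overlay the 0.5 decision boundary (sketch)
preds, _ = full_forward_propagation(np.transpose(grid_2d), params_values, NN_ARCHITECTURE)
make_plot(X_test, y_test, "Test set with decision boundary", XX=XX, YY=YY, preds=preds)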