手撕正则表达式

时间:2022-12-22 19:07:40

我们先撕简单的。a ab a|b aa* a(a|b)* 先不管匹配任意字符的. 重复>=1次的+ [^0-9]除0-9外 \digit数字等。

正则表达式(regular expression, re)为啥叫表达式,不叫正则字符串之类?因为它是个表达式。3+5*2是个表达式;两个字符串可以有连接运算,如"a"+"b"或"a"."b"得到"ab"。

在正则表达式里,a,b,c就像2,3,5,是被运算的数,. | * ()是运算符。请注意:ab是a和b拼接,人们为了省事不把拼接运算符写出来。

(3+5)*2=16,3+(5*2)=13。如果没有四则运算优先级和括号,3+5*2等于16还是13?运算符后置(后缀表达式)没有歧义,例如35+2*是mul(add(3,5), 2),352+*是mul(3, add(5,2))。mul: multipy. What are infix, postfix and prefix expressions?

我们分3步走:

  1. 把a(a|b)*变成aab|*.这样的后缀表达式,40行程序
  2. 用Thompson算法把后缀表达式变成NFA,算14行程序
  3. 用NFA检查是否匹配,算11行程序

第1步中缀变后缀请看代码。

第2步后缀变NFA。NFA可以像积木一样拼起来。下面分别是a, a|b, a*的NFA:

手撕正则表达式  手撕正则表达式

手撕正则表达式

图片是用dot - graphviz version 2.49.0画的。如 dot -o ab.png -Tpng todot.txt 或 dot -Tpng todot.txt >ab.png 。dot -h看帮助。

拼接NFA的代码:

NFA postfix_to_nfa(const char* pfstr) {
  Stack<NFA>  stk;
  for (const char* p = pfstr; *p; p++) {
    switch (*p) {
    case '.': stk.push(stk.pop() + stk.pop()); break;
    case '|': stk.push(stk.pop() | stk.pop()); break;
    case '*': stk.push(*stk.pop()); break;
    default: stk.push(*p);
    }
  }
  NFA nfa = stk.pop();
  if (!stk.empty()) error;
  return nfa;
}

运算符函数也不长,含打印,匹配等全部代码180行:

// 从ChrisZZ(zchrissirhcz@gmail.com)的程序改来的
#include <stdio.h>
#include <string.h>
#include <string>
#include <stack>
using namespace std;

#define error throw __LINE__

template<class T>struct Stack : public stack<T> {
  T pop() { T t = top(); stack<T>::pop(); return t; }
};

const char END = '\0', EPSILON = '\001'; // Epsilon (upper case Ε, lower case ε): empty

struct State { // 像链表里的node
  int id; // 自动加1的编号
  State*  next[2];  // 到next[0]的边是epsilon;到next[1]的是char
  char  ch;
  State(int ch_=256, State* p1=0, State* p0=0) : id(_id++), ch(ch_) { next[0] = p0; next[1] = p1; }
  static int  _id;
  static char _visited[256];  // 下标是State的编号,仅print时用
};
int State::_id;
char  State::_visited[256];

struct NFA {
  State *start, *end;
  NFA() : start(0), end(0) {}
  NFA(char ch) { end = new State(END); start = new State(ch, end); }

  NFA operator + (NFA nfa) {
    end->ch = EPSILON; end->next[1] = nfa.start;
    end = nfa.end;
    return *this;
  }

  NFA operator | (NFA nfa) {
    State *head = new State(EPSILON, start, nfa.start), *tail = new State(END);
    end->ch = EPSILON; end->next[1] = tail;
    end = tail; start = head;
    nfa.end->ch = EPSILON; nfa.end->next[1] = tail;
    return *this;
  }

  NFA operator * () {
    State *tail = new State(END), *head = new State(EPSILON, start, tail);
    end->ch = EPSILON; end->next[0] = start; end->next[1] = tail;
    end = tail; start = head;
    return *this;
  }

  void print(const char* file_name);

  const char* elm; // point to the end of the longest match

  const char* match(const char* str) { elm = str; visit4m(start, str);  return elm; }

  void visit4p(const State* s, FILE* fp); // visit for print
  void visit4m(const State* s, const char* str); // visit for match
};

NFA postfix_to_nfa(const char* pfstr) {
  Stack<NFA>  stk;
  for (const char* p = pfstr; *p; p++) {
    switch (*p) {
    case '.': stk.push(stk.pop() + stk.pop()); break;
    case '|': stk.push(stk.pop() | stk.pop()); break;
    case '*': stk.push(*stk.pop()); break;
    default: stk.push(*p);
    }
  }
  NFA nfa = stk.pop();
  if (!stk.empty()) error;
  return nfa;
}

void NFA::print(const char* file_name) { // 同时输出到屏幕和DOT文件
  puts("");
  FILE* fp = fopen(file_name, "wt");
  if (!fp) return;
  fputs("digraph {\n\"\"\n", fp);
  fputs("[shape = plaintext]\n", fp);
  fputs("\trankdir = LR\n", fp);
  memset(State::_visited, 0, sizeof(State::_visited)), visit4p(start, fp);
  fputs("}", fp), fclose(fp);
}

void NFA::visit4p(const State* st, FILE* fp) {
  if (State::_visited[st->id]) return;
  State::_visited[st->id] = 1;
  for (int i = 0; i < 2; i++) {
    if (State* p = st->next[i]) {
      char  label[16];
      if (st->ch == EPSILON) strcpy(label, "''"); else sprintf(label, "'%c'", st->ch);
      // DOT支持不带BOM的UTF-8编码的文件。ε的UTF-8编码是\xce\xb5
      printf("%d - %s -> %d\n", st->id, label, p->id);
      fprintf(fp, "%d -> %d [label = <%s>]\n", st->id, p->id, label);
      visit4p(p, fp);
    }
  }
}

void NFA::visit4m(const State* st, const char* str) {
  if (st == end) {
    if (str > elm) elm = str;
    return;
  }
  for (int i = 0; i < 2; i++) {
    if (State* p = st->next[i]) {
      if (st->ch == EPSILON) visit4m(p, str);
      if (st->ch == *str) visit4m(p, str + 1);
    }
  }
}

struct CountOf {
  int opnd; // a是opnd b是opnd ab.也是opnd
  int or; // |
};

string re_to_postfix(const char* re) {
  string  out;
  CountOf cntof = { 0 };
  stack<CountOf>  khdz; // KuoHao (parenthesis) 的栈
  const char* p;
  for (p = re; *p; p++) { 
    switch (char c = *p) {
    case '(':
      if (cntof.opnd > 1) out += '.'; // a(???
      khdz.push(cntof);
      cntof.or = cntof.opnd = 0;
      break;
    case ')':
      if (cntof.opnd == 0 || khdz.empty()) error; // ) ()
      while (--cntof.opnd > 0) out += '.'; // ((a|b)(c|d)) =1时不进循环 1个opnd不需要.
      while (cntof.or-- > 0) out += '|'; // =1时进循环
      cntof = khdz.top(); khdz.pop();
      ++cntof.opnd; // 如遇到(时还没有opnd,遇到(a)的)时,知道了(a)是个opnd
      break;
    case '*':
      if (cntof.opnd ==0 ) error;
      out += c;
      break;
    case '|': // a|b变ab| a|b|c变ab|c| ab|c变ab.c|
      if (cntof.opnd == 0) error;
      while (--cntof.opnd > 0) out += '.';
      ++cntof.or;
      break;
    default: // a变a ab变ab. abc变ab.c.
      if (cntof.opnd > 1) { --cntof.opnd; out += '.'; }
      out += c; ++cntof.opnd;
    } // switch
    // printf("%*c", 5, ' ')输出5个空格
    printf("%*c%s %d %d %s\n", p - re, ' ', p, cntof.opnd, cntof.or, out.c_str());
  } // for
  if (!khdz.empty()) error;
  while (--cntof.opnd > 0) out += '.';
  while (cntof.or-- > 0)  out +=  '|';
  printf("%*c%s     %s\n", p - re, ' ', p, out.c_str());
  return out;
}

int main(){
  try {
    //const char* re = "a";
    //const char* re = "a*";
    //const char* re = "ab";
    //const char* re = "a|b";
    const char* re = "((a|b)(c|d))*";
    NFA nfa = postfix_to_nfa(re_to_postfix(re).c_str());
    nfa.print("todot.txt");
    const char* s = "bdabc";
    const char* p = nfa.match(s);
    printf("\nmatch: %.*s\n", p - s, s);
  }
  catch(int n) { printf("Error at line %d.\n", n); }
  getchar();
  return 0;
}
View Code

print和match都是递归遍历图。