Tuesday, April 23, 2013

Find the nearest m nodes to a key in a BST (C++ code)

面试题。
先写个O(n)的,用in order traversal, 用queue存m个,每次尾巴新进来一个,跟头做比较,因为queue里面的元素是in order的, 如果新尾巴比头更接近key,就把头挤掉。如果不比头更接近,说明已经完成,就return queue。

// O(n) version: returns, in ascending order, the (up to) m nodes whose values
// are closest to key.
// An in-order walk visits values in sorted order, so q acts as a sliding window
// over the last m visited nodes. When a new node is closer than the window's
// front (smallest) node, the front is evicted. Once a node at or past key is
// not closer than the front, every later node is even farther, so we stop.
// The walk is iterative and self-contained. The original delegated to
// findhelper, which is declared after this function and received the queue by
// value, so the returned queue was always empty.
// Returns an empty queue when root is NULL or m <= 0.
queue<node *> findmclosest(node * root, int key,  int m){
    queue<node *> q;
    if (m <= 0) return q;

    vector<node *> path;  // explicit stack for the in-order traversal
    node *cur = root;
    while (cur != NULL || !path.empty()) {
        for (; cur != NULL; cur = cur->left) path.push_back(cur);
        cur = path.back();
        path.pop_back();

        if (static_cast<int>(q.size()) < m) {
            q.push(cur);
        } else if (abs(static_cast<long long>(cur->val) - key) <
                   abs(static_cast<long long>(q.front()->val) - key)) {
            // Strictly closer than the window's farthest-left member: slide.
            q.pop();
            q.push(cur);
        } else if (cur->val >= key) {
            // At or past key and not closer: all remaining nodes are farther.
            break;
        }
        // Otherwise the node is below key and ties with the front (a duplicate
        // value), so skip it and keep walking.
        cur = cur->right;
    }
    return q;
}

void findhelper(node *root, int key, queue<node *> q, int m){
     if(root == NULL) return;
     findhelper(root->left, key, q, m);
     if(q.size() < m) q.push(root->val);
     else{
         int diff = abs(root->val - key);
         if(diff >= abs(q.front() - key) ) return;
         else{
            q.pop();
            q.push(root);
        }
     }
     findhelper(root->right, key, q, m);
}


接下来是复杂一点的O(m*log n)解法。思路是先找到key,一边找一边用两个最多存m个节点的deque(当作栈使用,尾部为栈顶)存经过的节点:
pre: 存放值不大于key的节点(val <= key)
next:存放值不小于key的节点(val >= key);如果找到等于key的节点,它会同时放在pre和next的尾部
pre:每次pop元素后,要考虑该元素的左子树, next:每次pop后,要考虑该元素的右子树
用logn的时间找到节点后,注意到pre和next都是尾巴上的点跟key值最接近,就直接从尾巴一个一个取就好了。
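 // nextpush: push a node and its whole left spine onto next, so next.back() becomes the minimum of that subtree.
 // prepush: push a node and its whole right spine onto pre, so pre.back() becomes the maximum of that subtree.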
void nextpush(node *tmp, deque<node *> &next)
while(tmp){
     if(next.size() >= m) next.pop_front();
     next.push(tmp);
     tmp = tmp->left;
   }
}

void prepush(node *tmp, deque<node *> &pre)
while(tmp){
     if(pre.size() >= m) pre.pop_front();
     pre.push(tmp);
     tmp = tmp->right;
   }
}

// findkey: walks the BST search path from root toward key.
// Nodes with val <= key are appended to pre and nodes with val >= key to next.
// Along a search path, later nodes on each side are closer to key, so
// pre.back() and next.back() end up as the nearest candidates below and above.
// On an exact match, the matching node is pushed onto BOTH deques and the walk
// stops, so findmclosest can detect the match via pre.back() == next.back().
// The original kept recursing into both subtrees of the matched node, which
// buried it under other nodes and broke that check.
// If m > 0, each deque is capped at m entries by dropping its front (oldest,
// farthest) element. The original popped even when m == 0 (UB on an empty
// deque) and used the non-existent deque::pop/push.
void findkey(node *root, int key, deque<node *> &pre, deque<node *> &next, int m){
  for (node *cur = root; cur != NULL; ) {
    if (cur->val <= key) {
      if (m > 0 && static_cast<int>(pre.size()) >= m) pre.pop_front();
      pre.push_back(cur);
    }
    if (cur->val >= key) {
      if (m > 0 && static_cast<int>(next.size()) >= m) next.pop_front();
      next.push_back(cur);
    }
    if (cur->val == key) break;  // found: the node now sits on top of both deques
    cur = (cur->val < key) ? cur->right : cur->left;
  }
}


vector<int> findmclosest(node *root, int key, int m){
  int i = 0;
  vector<int> res;
  deque<node *> pre;
  deque<node *>next;
  node *tmp;
  findkey(root, key, pre,next, m);
//deal with key found in tree
 if(!pre.empty() && !next.empty &&pre.back() == next.back()){
   res[i++] = key;
   tmp = pre.back()->left;
   pre.pop_back();
   prepush(tmp,pre);
   tmp = next.back()->right;
   next.pop_back();
   nextpush(tmp,next);
 }

//start comparing and setting up result nodes.
while(!pre.empty() && !next.empty()){
  int prenode = pre.back();
  int nextnode = next.back();
  if(key - prenode->val > nextnode->val - key ) {
      res[i++] = nextnode->val;
      if(i == m) return res;
      next.pop_back();
      nextpush(nextnode->right, next);
  }
 else{
     res[i++] = prenode->val;
      if(i == m) return res;
      pre.pop_back();      prepush(prenode->left, pre);
  }
}

//when pre/next used up before getting m nodes

while(!pre.empty()){
   res[i++] = prenode->val;
      if(i == m) return res;
      pre.pop_back();   
   prepush(prenode->left, pre);
}

while(!next.empty()){
   res[i++] = nextnode->val;
      if(i == m) return res;
      next.pop_back();   
   nextpush(nextnode->right, next);
}

 return res;

}

No comments:

Post a Comment