面试题。
先写个O(n)的,用in order traversal, 用queue存m个,每次尾巴新进来一个,跟头做比较,因为queue里面的元素是in order的, 如果新尾巴比头更接近key,就把头挤掉。如果不比头更接近,说明已经完成,就return queue。
queue<node *> findmclosest(node * root, int key,  int m){
      queue<node *> q;
     findhelper(root, key, q, m);
     return q;
}
void findhelper(node *root, int key, queue<node *> q, int m){
     if(root == NULL) return;
     findhelper(root->left, key, q, m);
     if(q.size() < m) q.push(root->val);
     else{
         int diff = abs(root->val - key); 
         if(diff >= abs(q.front() - key) ) return;
         else{
            q.pop();
            q.push(root); 
        }
     } 
     findhelper(root->right, key, q, m); 
}
接下来是复杂一点的O(m*log n),思路是先找到key,一边找一边用两个m sized queue存经过的节点,
pre: 存放比key小的节点 
next:存放比key大的节点
pre:每次pop元素后,要考虑该元素的左子树, next:每次pop后,要考虑该元素的右子树
用logn的时间找到节点后,注意到pre和next都是尾巴上的点跟key值最接近,就直接从尾巴一个一个取就好了。
 //push next(min element of right subtree) and push pre(max element of left subtree)
void nextpush(node *tmp, deque<node *> &next)
while(tmp){
     if(next.size() >= m) next.pop_front();
     next.push(tmp);
     tmp = tmp->left;
   } 
} 
void prepush(node *tmp, deque<node *> &pre)
while(tmp){
     if(pre.size() >= m) pre.pop_front();
     pre.push(tmp);
     tmp = tmp->right;
   } 
}
//find key node
void findkey(node *root, int key, deque<node *> &pre, deque<node *> &next, int m){
  if(!root) return; 
  if(root->val <= key){
    if(pre.size() >= m) pre.pop();
    pre.push(root);  
    findkey(root->right, key, pre, next, m);
  }
 if(root->val >= key){
   if(next.size() >= m) next.pop();
    next.push(root);  
    findkey(root->left, key, pre, next, m); } 
} 
vector<int> findmclosest(node *root, int key, int m){
  int i = 0; 
  vector<int> res; 
  deque<node *> pre;
  deque<node *>next;
  node *tmp; 
  findkey(root, key, pre,next, m);
//deal with key found in tree 
 if(!pre.empty() && !next.empty &&pre.back() == next.back()){
   res[i++] = key;
   tmp = pre.back()->left;
   pre.pop_back(); 
   prepush(tmp,pre);
   tmp = next.back()->right;
   next.pop_back(); 
   nextpush(tmp,next);
 }
//start comparing and setting up result nodes.
while(!pre.empty() && !next.empty()){
  int prenode = pre.back();
  int nextnode = next.back();
  if(key - prenode->val > nextnode->val - key ) {
      res[i++] = nextnode->val;
      if(i == m) return res;
      next.pop_back();
      nextpush(nextnode->right, next);
  }
 else{
     res[i++] = prenode->val;
      if(i == m) return res;
      pre.pop_back();      prepush(prenode->left, pre);
  }
}
//when pre/next used up before getting m nodes
while(!pre.empty()){
    res[i++] = prenode->val;
      if(i == m) return res;
      pre.pop_back();   
   prepush(prenode->left, pre);
}
while(!next.empty()){
    res[i++] = nextnode->val;
      if(i == m) return res;
      next.pop_back();   
   nextpush(nextnode->right, next);
}
 return res;
}
 
 
No comments:
Post a Comment