123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- # (C) Copyright 2007
- # Andreas Kloeckner <inform -at- tiker.net>
- #
- # Use, modification and distribution is subject to the Boost Software
- # License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
- # http://www.boost.org/LICENSE_1_0.txt)
- #
- # Authors: Andreas Kloeckner
- import boost.mpi as mpi
- import random
- import sys
# Length a history (list of visited ranks) must reach before a worker
# routes it back to rank 0 instead of on to another worker.
MAX_GENERATIONS = 20

# Message tags distinguishing the kinds of traffic exchanged between ranks.
(
    TAG_DEBUG,
    TAG_DATA,
    TAG_TERMINATE,
    TAG_PROGRESS_REPORT,
) = range(4)
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive is kept posted per tag; `wait` blocks until
    any one of them completes.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- communicator whose ``irecv(tag=...)`` posts the receives
        tags -- iterable of message tags to listen for
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> outstanding irecv request for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of the tags has been received.

        Returns a ``(status, data)`` pair for the completed receive.
        """
        # Post a receive for every tag that has none outstanding (all of
        # them on the first call, only the just-completed one afterwards).
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)

        # list(...) keeps this working where dict.values() is a view.
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _ = mpi.wait_any(requests)

        # The finished request is looked up by its tag, not by the index
        # returned from wait_any, since the dict order is not relied upon.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it.

        The cancelled requests are dropped without being waited on.
        """
        # values() rather than itervalues(): the latter exists only on
        # Python 2 dicts.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}
def rank0():
    """Coordinate the run on rank 0.

    Sends 15 empty histories per worker rank to randomly chosen workers,
    then listens for returned histories, progress reports, debug messages
    and termination acknowledgements. Once every history has come back,
    each worker is told to terminate; the function returns after all
    workers have acknowledged and the progress-report counts are
    consistent.
    """
    worker_count = mpi.size - 1
    sent_histories = worker_count * 15
    print("sending %d packets on their way" % sent_histories)

    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
    # With no workers nothing was sent, so there is nothing to wait on.
    if sent_histories:
        mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length -> number of progress reports seen at that length.
    progress_reports = {}
    # Ranks that have acknowledged the terminate request.
    dead_kids = []

    tgl = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        """Return True once all reports are in and all workers are done."""
        if any(count != sent_histories
               for count in progress_reports.values()):
            return False
        return len(dead_kids) == worker_count

    # Completion is tested before each blocking wait, so a run without
    # workers (nothing sent, nothing to receive) finishes immediately
    # instead of waiting for messages that can never arrive.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            generation = len(data)
            progress_reports[generation] = (
                progress_reports.get(generation, 0) + 1)
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    print("OK")
def comm_rank():
    """Relay histories on a worker rank until told to terminate.

    Each received history is reported to rank 0, extended with this
    rank's number, and then forwarded: to a random worker while it is
    shorter than MAX_GENERATIONS, otherwise back to rank 0. A terminate
    message is acknowledged to rank 0 and ends the loop.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)

        if status.tag == TAG_DATA:
            # The progress report carries the history as it was received,
            # i.e. before this rank is appended to it.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)

            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge so rank 0 can count this worker as finished.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))
def main():
    """Run the role that matches this process's rank.

    This program sends around messages consisting of lists of visited
    nodes randomly. After MAX_GENERATIONS, they are returned to rank 0.
    Rank 0 coordinates; every other rank relays.
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()


if __name__ == "__main__":
    main()
|