rsa_context_mp Class Reference

#include <rsa_context_mp.hh>

Inheritance diagram for rsa_context_mp:
Inheritance graph
[legend]
Collaboration diagram for rsa_context_mp:
Collaboration graph
[legend]

Public Member Functions

 rsa_context_mp (int keylen)
 rsa_context_mp (const std::string &filename, const std::string &passwd)
 rsa_context_mp (const char *filename, const char *passwd)
virtual void priv_decrypt (unsigned char *out, int *out_len, const unsigned char *in, int in_len)
virtual void priv_decrypt_batch (unsigned char **out, int *out_len, const unsigned char **in, const int *in_len, int n)
void priv_decrypt_stream (unsigned char **out, int *out_len, const unsigned char **in, const int *in_len, int n, unsigned int stream_id)
bool sync (unsigned int stream_id, bool block=true, bool copy_result=true)
void set_device_context (device_context *dev_ctx)

Detailed Description

class rsa_context_mp

Interface for RSA processing on the GPU.


Constructor & Destructor Documentation

rsa_context_mp::rsa_context_mp ( int  keylen  ) 

Constructor. It will randomly generate the RSA key pair with the given key length.

Parameters:
keylen Length of the key in bits. Supported lengths are 512, 1024, 2048, and 4096 bits.

00008         : rsa_context(keylen)
00009 {
00010         gpu_setup();
00011 }

rsa_context_mp::rsa_context_mp ( const std::string &  filename,
const std::string &  passwd 
)

Constructor. It loads the key from the file using the given password. Currently only the PEM format is supported.

Parameters:
filename Path of the file that contains the RSA private key.
passwd Password used to encrypt the private key.

00015         : rsa_context(filename, passwd)
00016 {
00017         gpu_setup();
00018 }

rsa_context_mp::rsa_context_mp ( const char *  filename,
const char *  passwd 
)

Constructor. It loads the key from the file using the given password. Currently only the PEM format is supported.

Parameters:
filename Path of the file that contains the RSA private key.
passwd Password used to encrypt the private key.

00021         : rsa_context(filename, passwd)
00022 {
00023         gpu_setup();
00024 }


Member Function Documentation

void rsa_context_mp::priv_decrypt ( unsigned char *  out,
int *  out_len,
const unsigned char *  in,
int  in_len 
) [virtual]

Decrypt the data with the RSA algorithm using the private key. All encryption/decryption methods assume RSA_PKCS1_PADDING.

Parameters:
out Buffer for output.
out_len In: allocated buffer space for output result, Out: output size.
in Buffer that stores cipher text.
in_len Length of the ciphertext.

Reimplemented from rsa_context.

00068 {
00069         priv_decrypt_batch(&out, out_len, &in, &in_len, 1);
00070 }

Here is the call graph for this function:

void rsa_context_mp::priv_decrypt_batch ( unsigned char **  out,
int *  out_len,
const unsigned char **  in,
const int *  in_len,
int  n 
) [virtual]

Decrypt the data with the RSA algorithm using the private key, in a batch. All encryption/decryption methods assume RSA_PKCS1_PADDING.

Parameters:
out Buffers for plain text.
out_len In: allocated buffer space for output results, Out: output sizes.
in Buffers that store the ciphertexts.
in_len Lengths of the ciphertexts.
n Ciphertexts count.

Reimplemented from rsa_context.

00075 {
00076         // by default, stream is not used
00077         priv_decrypt_stream(out, out_len, in, in_len, n, 0);
00078         sync(0);
00079 }

Here is the call graph for this function:

Here is the caller graph for this function:

void rsa_context_mp::priv_decrypt_stream ( unsigned char **  out,
int *  out_len,
const unsigned char **  in,
const int *  in_len,
int  n,
unsigned int  stream_id 
)

Decrypt the data with the RSA algorithm using the private key, in a batch. All encryption/decryption methods assume RSA_PKCS1_PADDING. It runs asynchronously; use sync() to check for completion.

Parameters:
out Buffers for plain text.
out_len In: allocated buffer space for output results, Out: output sizes.
in Buffers that store the ciphertexts.
in_len Lengths of the ciphertexts.
n Ciphertexts count.
stream_id Stream index. 0 <= stream_id <= max_stream; index 0 is the one priv_decrypt_batch() passes when streams are not used.

00084 {
00085         assert(is_crt_available());
00086         assert(0 < n && n <= max_batch);
00087         assert(n <= MP_MAX_NUM_PAIRS);
00088         assert(stream_id <= max_stream);
00089         assert(dev_ctx_ != NULL);
00090         assert(dev_ctx_->get_state(stream_id) == READY);
00091         dev_ctx_->set_state(stream_id, WAIT_KERNEL);
00092 
00093         int word_len = (get_key_bits() / 2) / BITS_PER_WORD;
00094         int S = word_len;
00095         int num_blks = ((n + MP_MSGS_PER_BLOCK - 1) / MP_MSGS_PER_BLOCK) * 2;
00096         dev_ctx_->clear_checkbits(stream_id, num_blks);
00097         streams[stream_id].post_launched = false;
00098 
00099         for (int i = 0; i < n; i++) {
00100                 BN_bin2bn(in[i], in_len[i], in_bn_p);
00101                 BN_bin2bn(in[i], in_len[i], in_bn_q);
00102                 assert(in_bn_p != NULL);
00103                 assert(in_bn_q != NULL);
00104 
00105                 //assert(BN_cmp(in_bn_p, rsa->n) < 0);
00106 
00107                 BN_nnmod(in_bn_p, in_bn_p, rsa->p, bn_ctx);     // TODO: test BN_nnmod
00108                 BN_nnmod(in_bn_q, in_bn_q, rsa->q, bn_ctx);
00109 
00110                 mp_bn2mp(streams[stream_id].a + (i * 2 * MAX_S), in_bn_p, word_len);
00111                 mp_bn2mp(streams[stream_id].a + (i * 2 * MAX_S) + MAX_S, in_bn_q, word_len);
00112         }
00113 
00114         //copy input and execute kernel
00115         mp_modexp_crt(streams[stream_id].a,
00116                         n, word_len,
00117                         streams[stream_id].ret_d, streams[stream_id].a_d,
00118                         sw_d,
00119                         n_d,
00120                         np_d,
00121                         r_sqr_d,
00122                       dev_ctx_->get_stream(stream_id),
00123                       stream_id,
00124                       dev_ctx_->get_dev_checkbits(stream_id));
00125 
00126 
00127         streams[stream_id].n = n;
00128         streams[stream_id].out = out;
00129         streams[stream_id].out_len = out_len;
00130 }

Here is the call graph for this function:

Here is the caller graph for this function:

void rsa_context_mp::set_device_context ( device_context *  dev_ctx  )  [inline]

Sets the device context for the rsa context. TODO: Move dev_ctx initialization to constructor.

Parameters:
dev_ctx device context.

00117 {dev_ctx_ = dev_ctx;};

bool rsa_context_mp::sync ( unsigned int  stream_id,
bool  block = true,
bool  copy_result = true 
)

Synchronize/query the execution on the stream. This function can be used either to check whether the current execution on the stream has finished or to wait until that execution finishes.

Parameters:
stream_id Stream index.
block Whether to wait for the execution to finish. true by default.
copy_result If false, the result is not copied back to the CPU.
Returns:
true if the current operation on the stream is finished, otherwise false.

00133 {
00134         assert(stream_id <= max_stream);
00135         int word_len = (get_key_bits() / 2) / BITS_PER_WORD;
00136 
00137         if (dev_ctx_->get_state(stream_id) == READY)
00138                 return true;
00139 
00140         //blocking case
00141         if (block) {
00142                 //wait for previous operation to finish
00143                 dev_ctx_->sync(stream_id, true);
00144                 if (dev_ctx_->get_state(stream_id) == WAIT_KERNEL &&
00145                          streams[stream_id].post_launched == false) {
00146                         //post kernel launch
00147                         int S = word_len;
00148                         int num_blks = ((streams[stream_id].n + MP_MSGS_PER_BLOCK - 1) / MP_MSGS_PER_BLOCK);
00149                         dev_ctx_->clear_checkbits(stream_id, num_blks);
00150                         mp_modexp_crt_post_kernel(streams[stream_id].ret,
00151                                                   streams[stream_id].ret_d,
00152                                                   n_d,
00153                                                   np_d,
00154                                                   r_sqr_d,
00155                                                   iqmp_d,
00156                                                   streams[stream_id].n,
00157                                                   word_len,
00158                                                   block,
00159                                                   dev_ctx_->get_stream(stream_id),
00160                                                   dev_ctx_->get_dev_checkbits(stream_id));
00161                         streams[stream_id].post_launched = true;
00162                         dev_ctx_->sync(stream_id, true);
00163                 }
00164 
00165                 if (dev_ctx_->get_state(stream_id) == WAIT_KERNEL &&
00166                     streams[stream_id].post_launched == true) {
00167                         //copy result
00168                         dev_ctx_->set_state(stream_id, WAIT_COPY);
00169                         cutilSafeCall(cudaMemcpyAsync(streams[stream_id].ret,
00170                                                       streams[stream_id].ret_d,
00171                                                       sizeof(WORD[2][MAX_S]) * streams[stream_id].n,
00172                                                       cudaMemcpyDeviceToHost,
00173                                                       dev_ctx_->get_stream(stream_id)));
00174                         dev_ctx_->sync(stream_id, true);
00175                 }
00176 
00177                 if (dev_ctx_->get_state(stream_id) == WAIT_COPY) {
00178                         dev_ctx_->set_state(stream_id, READY);
00179                 }
00180 
00181                 //move result to out from gathered buffer
00182                 for (int i = 0; i < streams[stream_id].n; i++) {
00183                         int rsa_bytes = get_key_bits() / 8;
00184 
00185                         int ret = RSA_padding_check_PKCS1_type_2(streams[stream_id].out[i],
00186                                                                  streams[stream_id].out_len[i],
00187                                                                  (unsigned char *)(streams[stream_id].ret + (i * 2 * MAX_S)) + 1,
00188                                                                  rsa_bytes - 1,
00189                                                                  rsa_bytes);
00190                         if (ret == -1) {
00191                                 for (int j = 0; j < 2 * word_len * (int)sizeof(WORD); j++)
00192                                         printf("%02x ", *(((unsigned char *)(streams[stream_id].ret + (i * 2 * MAX_S)) + j)));
00193                                 printf("\n");
00194                                 assert(false);
00195                         }
00196                         streams[stream_id].out_len[i] = ret;
00197                 }
00198 
00199                 return true;
00200         }
00201 
00202 
00203         //nonblocking case
00204         if (dev_ctx_->get_state(stream_id) == WAIT_KERNEL) {
00205                 if (!dev_ctx_->sync(stream_id, false))
00206                         return false;
00207                 if (!streams[stream_id].post_launched) {
00208                         //start post kernel execution
00209                         int S = word_len;
00210                         int num_blks = ((streams[stream_id].n + MP_MSGS_PER_BLOCK - 1) / MP_MSGS_PER_BLOCK);
00211 
00212                         streams[stream_id].post_launched = true;
00213                         dev_ctx_->clear_checkbits(stream_id, num_blks);
00214                         mp_modexp_crt_post_kernel(streams[stream_id].ret,
00215                                                   streams[stream_id].ret_d,
00216                                                   n_d,
00217                                                   np_d,
00218                                                   r_sqr_d,
00219                                                   iqmp_d,
00220                                                   streams[stream_id].n,
00221                                                   word_len,
00222                                                   block,
00223                                                   dev_ctx_->get_stream(stream_id),
00224                                                   dev_ctx_->get_dev_checkbits(stream_id));
00225                         streams[stream_id].post_launched = true;
00226                         return false;
00227                 } else {
00228                         //start copying result
00229                         dev_ctx_->set_state(stream_id, WAIT_COPY);
00230                         cutilSafeCall(cudaMemcpyAsync(streams[stream_id].ret,
00231                                                       streams[stream_id].ret_d,
00232                                                       sizeof(WORD[2][MAX_S]) * streams[stream_id].n,
00233                                                       cudaMemcpyDeviceToHost,
00234                                                       dev_ctx_->get_stream(stream_id)));
00235                         return false;
00236                 }
00237 
00238         } else if (dev_ctx_->get_state(stream_id) == WAIT_COPY) {
00239                 if (!dev_ctx_->sync(stream_id, false))
00240                         return false;
00241 
00242                 //move result to out from gathered buffer
00243                 for (int i = 0; i < streams[stream_id].n; i++) {
00244                         int rsa_bytes = get_key_bits() / 8;
00245 
00246                         int ret = RSA_padding_check_PKCS1_type_2(streams[stream_id].out[i],
00247                                                                  streams[stream_id].out_len[i],
00248                                                                  (unsigned char *)(streams[stream_id].ret + (i * 2 * MAX_S)) + 1,
00249                                                                  rsa_bytes - 1,
00250                                                                  rsa_bytes);
00251                         if (ret == -1) {
00252                                 for (int j = 0; j < 2 * word_len * (int)sizeof(WORD); j++)
00253                                         printf("%02x ", *(((unsigned char *)(streams[stream_id].ret + (i * 2 * MAX_S)) + j)));
00254                                 printf("\n");
00255                                 assert(false);
00256                         }
00257                         streams[stream_id].out_len[i] = ret;
00258                 }
00259 
00260                 dev_ctx_->set_state(stream_id, READY);
00261                 return true;
00262         }
00263         return false;
00264 }

Here is the call graph for this function:

Here is the caller graph for this function:

 All Data Structures Functions
Generated on Tue Oct 18 10:20:21 2011 for libgpucrypto by  doxygen 1.6.3