与 C.gather 相关的 CNTK 运行时错误

CNTK Runtime Error related to C.gather

我在尝试测试 CNTK 时遇到错误。

我正在尝试使用 input_variable 作为索引对参数进行切片。使用C.gather进行切片会导致backprop过程内存错误

CPU、GPU、Docker、本地安装等所有cntk2环境都出现错误

错误消息和调用堆栈

RuntimeError: CUBLAS failure 11: CUBLAS_STATUS_MAPPING_ERROR ; GPU=0 ; hostname=.... ; expr=cublasGetMatrix((int) numRows, (int) numCols, sizeof(ElemType), Data(), (int) GetNumRows(), dst, (int) colStride)

[CALL STACK] Microsoft::MSR::CNTK::CudaTimer:: Stop - Microsoft::MSR::CNTK::Matrix:: CopySection - Microsoft::MSR::CNTK::Matrix:: AssignValuesOf - CNTK::NDArrayView:: CopyFrom - CNTK::NDArrayView::NDArrayView
- CNTK::TrainingParameterSchedule:: Serialize - CNTK::DictionaryValue:: Save - CNTK::Trainer:: SummarizeTrainingProgress - PyInit__cntk_py - PyCFunction_Call - PyEval_GetFuncDesc - PyEval_EvalFrameEx - PyEval_GetFuncDesc (x2) - PyEval_EvalFrameEx (x2)

代码

x = input_val[:-2]
p1 = input_val[-2]
p2 = input_val[-1]

activator = relu

W1 = C.Parameter((slices,input_dim,hidden_layers_dim), init=C.glorot_normal(), name='W1')
b1 = C.Parameter((slices,hidden_layers_dim), init=0, name='b1')
W2 = C.Parameter((slices,hidden_layers_dim,hidden_layers_dim), init=C.glorot_normal(), name='W2')
b2 = C.Parameter((slices,hidden_layers_dim), init=0, name='b2')
W3 = C.Parameter((slices,hidden_layers_dim,output_dim), init=C.glorot_normal(), name='W3')
b3 = C.Parameter((slices,output_dim), init=0, name='b3')

W11 = C.gather(W1, p1)
b11 = C.gather(b1, p1)
W1x = C.reshape(W11, (input_dim,hidden_layers_dim))
b1x = C.reshape(b11, (hidden_layers_dim,))

W21 = C.gather(W2, p1)
b21 = C.gather(b2, p1)
W2x = C.reshape(W21, (hidden_layers_dim,hidden_layers_dim))
b2x = C.reshape(b21, (hidden_layers_dim,))

W31 = C.gather(W3, p1)
b31 = C.gather(b3, p1)
W3x = C.reshape(W31, (hidden_layers_dim,output_dim))
b3x = C.reshape(b31, (output_dim,))

x = activator(C.times(x, W1x) + b1x)
x = activator(C.times(x, W2x) + b2x)
x = C.times(x, W3x) + b3x

我无法用最新的母版重现这个。这很可能是在 CNTK 2.1 发布后立即修复的错误。下一个版本(2.2)将在 2017 年 9 月 15 日左右发布。如果问题仍然存在,升级到 2.2 后,打开 github issue 可能是解决此问题的正确方法。